Skip to content

Commit

Permalink
Clean up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
atdrendel committed Feb 26, 2025
1 parent b3b7b2f commit f91ea6f
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 130 deletions.
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ let package = Package(
),
],
// resources: [
// .copy("Resources/DeepSeek-R1-Distill-Qwen-7B-4bit"),
// .copy("Resources/gemma-2-2b-it-4bit"),
// .copy("Resources/Phi-3.5-mini-instruct-4bit"),
// .copy("Resources/Qwen2.5-1.5B-Instruct-4bit"),
// ],
linkerSettings: [
Expand Down
10 changes: 6 additions & 4 deletions Sources/SHLLM/LLM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import struct Hub.Config
import Metal
import MLX
import MLXLLM
import MLXNN
import MLXLMCommon
import MLXNN
import Tokenizers
Expand Down Expand Up @@ -199,7 +198,10 @@ extension LLM {
messages: [Message],
maxTokenCount: Int = 1024 * 1024
) async throws -> T {
let result = try await request(.init(messages: messages, tools: tools.toSpec()), maxTokenCount: maxTokenCount)
let result = try await request(
.init(messages: messages, tools: tools.toSpec()),
maxTokenCount: maxTokenCount
)

let decoder = JSONDecoder()

Expand All @@ -213,7 +215,7 @@ extension LLM {
messages: [Message],
maxTokenCount: Int = 1024 * 1024
) async throws -> String {
return try await request(.init(messages: messages), maxTokenCount: maxTokenCount)
try await request(.init(messages: messages), maxTokenCount: maxTokenCount)
}

func request(
Expand Down Expand Up @@ -281,7 +283,7 @@ private extension String {
let prefix = "<tool_call>\n"
let suffix = "\n</tool_call>"

var copy = self.trimmingCharacters(in: .whitespacesAndNewlines)
var copy = trimmingCharacters(in: .whitespacesAndNewlines)
copy.removeFirst(prefix.count)
copy.removeLast(suffix.count)
return copy
Expand Down
6 changes: 5 additions & 1 deletion Sources/SHLLM/ModelProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ public extension ModelProtocol {
maxTokenCount: Int = 1024 * 1024
) async throws -> T {
try await llm.withLock { llm in
try await llm.request(tools: tools, messages: messages, maxTokenCount: maxTokenCount)
try await llm.request(
tools: tools,
messages: messages,
maxTokenCount: maxTokenCount
)
}
}

Expand Down
191 changes: 148 additions & 43 deletions Sources/SHLLM/Tools.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ public struct Tools {
}

public func toSpec() -> [[String: any Sendable]] {
return functions.map { $0.toSpec() }
functions.map { $0.toSpec() }
}
}

Expand All @@ -34,16 +34,16 @@ public struct ToolFunction {

var functionSpec: [String: any Sendable] = [
"name": name,
"parameters": propertiesSpec
"parameters": propertiesSpec,
]

if let description = description {
if let description {
functionSpec["description"] = description
}

return [
"type": "function",
"function": functionSpec
"function": functionSpec,
]
}
}
Expand All @@ -54,35 +54,124 @@ public struct ToolFunctionParameter {
public let description: String?
public let required: Bool

static func string(name: String, description: String? = nil, required: Bool = false, format: ToolFunctionStringFormat = .plain, minLength: Int? = nil, maxLength: Int? = nil, restrictTo: [String]? = nil, asConst: String? = nil) -> ToolFunctionParameter {
.init(name: name, type: .string(format: format, minLength: minLength, maxLength: maxLength, restrictTo: restrictTo, asConst: asConst), description: description, required: required)
static func string(
name: String,
description: String? = nil,
required: Bool = false,
format: ToolFunctionStringFormat = .plain,
minLength: Int? = nil,
maxLength: Int? = nil,
restrictTo: [String]? = nil,
asConst: String? = nil
) -> ToolFunctionParameter {
.init(
name: name,
type: .string(
format: format,
minLength: minLength,
maxLength: maxLength,
restrictTo: restrictTo,
asConst: asConst
),
description: description,
required: required
)
}

static func number(name: String, description: String? = nil, required: Bool = false, minimum: Double? = nil, maximum: Double? = nil, asConst: Double? = nil, multipleOf: Double? = nil) -> ToolFunctionParameter {
.init(name: name, type: .number(minimum: minimum, maximum: maximum, asConst: asConst, multipleOf: multipleOf), description: description, required: required)
static func number(
name: String,
description: String? = nil,
required: Bool = false,
minimum: Double? = nil,
maximum: Double? = nil,
asConst: Double? = nil,
multipleOf: Double? = nil
) -> ToolFunctionParameter {
.init(
name: name,
type: .number(
minimum: minimum,
maximum: maximum,
asConst: asConst,
multipleOf: multipleOf
),
description: description,
required: required
)
}

static func integer(name: String, description: String? = nil, required: Bool = false, minimum: Int? = nil, maximum: Int? = nil, asConst: Int? = nil, multipleOf: Int? = nil) -> ToolFunctionParameter {
.init(name: name, type: .integer(minimum: minimum, maximum: maximum, asConst: asConst, multipleOf: multipleOf), description: description, required: required)
static func integer(
name: String,
description: String? = nil,
required: Bool = false,
minimum: Int? = nil,
maximum: Int? = nil,
asConst: Int? = nil,
multipleOf: Int? = nil
) -> ToolFunctionParameter {
.init(
name: name,
type: .integer(
minimum: minimum,
maximum: maximum,
asConst: asConst,
multipleOf: multipleOf
),
description: description,
required: required
)
}

static func array(name: String, description: String? = nil, required: Bool = false, items: ToolFunctionParameterType? = nil) -> ToolFunctionParameter {
.init(name: name, type: .array(items: items), description: description, required: required)
static func array(
name: String,
description: String? = nil,
required: Bool = false,
items: ToolFunctionParameterType? = nil
) -> ToolFunctionParameter {
.init(
name: name,
type: .array(items: items),
description: description,
required: required
)
}

static func object(name: String, description: String? = nil, required: Bool = false, properties: [ToolFunctionParameter] = []) -> ToolFunctionParameter {
.init(name: name, type: .object(properties: properties), description: description, required: required)
static func object(
name: String,
description: String? = nil,
required: Bool = false,
properties: [ToolFunctionParameter] = []
) -> ToolFunctionParameter {
.init(
name: name,
type: .object(properties: properties),
description: description,
required: required
)
}

static func boolean(name: String, description: String? = nil, required: Bool = false) -> ToolFunctionParameter {
static func boolean(
name: String,
description: String? = nil,
required: Bool = false
) -> ToolFunctionParameter {
.init(name: name, type: .boolean, description: description, required: required)
}

static func null(name: String, description: String? = nil, required: Bool = false) -> ToolFunctionParameter {
static func null(
name: String,
description: String? = nil,
required: Bool = false
) -> ToolFunctionParameter {
.init(name: name, type: .null, description: description, required: required)
}

init(name: String, type: ToolFunctionParameterType, description: String? = nil, required: Bool = false) {
init(
name: String,
type: ToolFunctionParameterType,
description: String? = nil,
required: Bool = false
) {
self.name = name
self.type = type
self.description = description
Expand All @@ -94,10 +183,10 @@ public struct ToolFunctionParameter {

var dict: [String: any Sendable] = [
"name": name,
"type": typeSpec["type"]
"type": typeSpec["type"],
]

if let description = description {
if let description {
dict["description"] = description
}

Expand All @@ -111,25 +200,41 @@ public struct ToolFunctionParameter {
public indirect enum ToolFunctionParameterType {
case array(items: ToolFunctionParameterType? = nil)
case object(properties: [ToolFunctionParameter] = [])
case number(minimum: Double? = nil, maximum: Double? = nil, asConst: Double? = nil, multipleOf: Double? = nil)
case integer(minimum: Int? = nil, maximum: Int? = nil, asConst: Int? = nil, multipleOf: Int? = nil)
case string(format: ToolFunctionStringFormat = .plain, minLength: Int? = nil, maxLength: Int? = nil, restrictTo: [String]? = nil, asConst: String? = nil)
case number(
minimum: Double? = nil,
maximum: Double? = nil,
asConst: Double? = nil,
multipleOf: Double? = nil
)
case integer(
minimum: Int? = nil,
maximum: Int? = nil,
asConst: Int? = nil,
multipleOf: Int? = nil
)
case string(
format: ToolFunctionStringFormat = .plain,
minLength: Int? = nil,
maxLength: Int? = nil,
restrictTo: [String]? = nil,
asConst: String? = nil
)
case boolean
case null

public func toSpec() -> [String: any Sendable] {
switch self {
case .array(let items):
case let .array(items):
var dict: [String: Any] = ["type": "array"]
if let items = items {
if let items {
dict["items"] = items.toSpec()
}
return dict

case .object(let properties):
case let .object(properties):
var _properties: [String: any Sendable] = [:]

properties.forEach { prop in
for prop in properties {
var dict = prop.toSpec()
let name = prop.name
dict.removeValue(forKey: "name")
Expand All @@ -141,32 +246,32 @@ public indirect enum ToolFunctionParameterType {
"properties": _properties,
]

case .number(let minimum, let maximum, let asConst, let multipleOf):
case let .number(minimum, maximum, asConst, multipleOf):
var dict: [String: Any] = ["type": "number"]
if let minimum = minimum { dict["minimum"] = minimum }
if let maximum = maximum { dict["maximum"] = maximum }
if let asConst = asConst { dict["const"] = asConst }
if let multipleOf = multipleOf { dict["multipleOf"] = multipleOf }
if let minimum { dict["minimum"] = minimum }
if let maximum { dict["maximum"] = maximum }
if let asConst { dict["const"] = asConst }
if let multipleOf { dict["multipleOf"] = multipleOf }
return dict

case .integer(let minimum, let maximum, let asConst, let multipleOf):
case let .integer(minimum, maximum, asConst, multipleOf):
var dict: [String: Any] = ["type": "integer"]
if let minimum = minimum { dict["minimum"] = minimum }
if let maximum = maximum { dict["maximum"] = maximum }
if let asConst = asConst { dict["const"] = asConst }
if let multipleOf = multipleOf { dict["multipleOf"] = multipleOf }
if let minimum { dict["minimum"] = minimum }
if let maximum { dict["maximum"] = maximum }
if let asConst { dict["const"] = asConst }
if let multipleOf { dict["multipleOf"] = multipleOf }
return dict

case .string(let format, let minLength, let maxLength, let restrictTo, let asConst):
case let .string(format, minLength, maxLength, restrictTo, asConst):
var dict: [String: Any] = ["type": "string"]

if format != .plain {
dict["format"] = format.rawValue
}
if let minLength = minLength { dict["minLength"] = minLength }
if let maxLength = maxLength { dict["maxLength"] = maxLength }
if let restrictTo = restrictTo { dict["enum"] = restrictTo }
if let asConst = asConst { dict["const"] = asConst }
if let minLength { dict["minLength"] = minLength }
if let maxLength { dict["maxLength"] = maxLength }
if let restrictTo { dict["enum"] = restrictTo }
if let asConst { dict["const"] = asConst }
return dict

case .boolean:
Expand All @@ -179,6 +284,6 @@ public indirect enum ToolFunctionParameterType {
}

public enum ToolFunctionStringFormat: String {
case plain = "plain"
case date = "date"
case plain
case date
}
Loading

0 comments on commit f91ea6f

Please sign in to comment.