Skip to content

Commit

Permalink
Merge branch 'main' into temporarily-disable-new-tensor-validation
Browse files Browse the repository at this point in the history
  • Loading branch information
atdrendel authored Feb 26, 2025
2 parents 647eea6 + dda2c2b commit f3d9037
Show file tree
Hide file tree
Showing 45 changed files with 279 additions and 268 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: CI

on: push

jobs:
test:
runs-on: macos-15

steps:
- uses: actions/checkout@v4
- name: Select Xcode 16
run: sudo xcode-select -s /Applications/Xcode_16.2.app
- name: Test
run: swift test
7 changes: 5 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,11 @@ let package = Package(
moduleAliases: ["Models": "TransformersModels"]
),
],
resources: [
.copy("Resources"),
// resources: [
// .copy("Resources/gemma-2-2b-it-4bit"),
// ],
linkerSettings: [
.linkedFramework("CoreGraphics", .when(platforms: [.macOS])),
]
),
.testTarget(
Expand Down
91 changes: 54 additions & 37 deletions Sources/SHLLM/Bundle+SHLLM.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,62 @@ extension Bundle {
// `Bundle.module` does not work correctly in tests. There is a thread on the
// Swift Forums describing the issue:
// https://forums.swift.org/t/swift-5-3-spm-resources-in-tests-uses-wrong-bundle-path/37051
static var shllm: Bundle = {
let bundleName = "SHLLM_SHLLM"

var candidates = [
// Bundle should be present here when the package is linked into an App.
Bundle.main.resourceURL,

// Bundle should be present here when the package is linked into a framework.
Bundle(for: BundleLocator.self).resourceURL,

// For command-line tools.
Bundle.main.bundleURL,

// iOS Xcode previews.
Bundle(for: BundleLocator.self)
.resourceURL?
.deletingLastPathComponent()
.deletingLastPathComponent()
.deletingLastPathComponent(),

// macOS Xcode previews.
Bundle(for: BundleLocator.self)
.resourceURL?
.deletingLastPathComponent()
.deletingLastPathComponent(),
]

// For tests
// https://forums.swift.org/t/swift-5-3-spm-resources-in-tests-uses-wrong-bundle-path/37051/21
candidates += Bundle.allBundles.compactMap(\.resourceURL)

for candidate in candidates {
let bundlePath = candidate?.appendingPathComponent(bundleName + ".bundle")
if let bundle = bundlePath.flatMap(Bundle.init(url:)) {
return bundle
static var shllm: Bundle {
get throws {
let bundleName = "SHLLM_SHLLM"

var candidates = [
// Bundle should be present here when the package is linked into an App.
Bundle.main.resourceURL,

// Bundle should be present here when the package is linked into a framework.
Bundle(for: BundleLocator.self).resourceURL,

// For command-line tools.
Bundle.main.bundleURL,

// iOS Xcode previews.
Bundle(for: BundleLocator.self)
.resourceURL?
.deletingLastPathComponent()
.deletingLastPathComponent()
.deletingLastPathComponent(),

// macOS Xcode previews.
Bundle(for: BundleLocator.self)
.resourceURL?
.deletingLastPathComponent()
.deletingLastPathComponent(),
]

// For tests
// https://forums.swift.org/t/swift-5-3-spm-resources-in-tests-uses-wrong-bundle-path/37051/21
candidates += Bundle.allBundles.compactMap(\.resourceURL)

for candidate in candidates {
let bundlePath = candidate?.appendingPathComponent(bundleName + ".bundle")
if let bundle = bundlePath.flatMap(Bundle.init(url:)) {
return bundle
}
}

throw SHLLMError.missingBundle(bundleName)
}
}

func directory(named name: String) throws -> URL {
if let url = url(forResource: name, withExtension: nil) {
return url
} else if let url = url(
forResource: name,
withExtension: nil,
subdirectory: "Resources"
) {
return url
} else {
throw SHLLMError.directoryNotFound(name)
}
fatalError("unable to find bundle named '\(bundleName)'")
}()
}
}

private class BundleLocator: NSObject {}
17 changes: 17 additions & 0 deletions Sources/SHLLM/LLM.swift
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
import CoreGraphics
import Foundation
import struct Hub.Config
import Metal
import MLX
import MLXLLM
import MLXNN
import MLXLMCommon
import Tokenizers

public final class LLM {
public static var isSupportedDevice: Bool {
guard let _ = MTLCreateSystemDefaultDevice() else {
return false
}
return true
}

static var assertSupportedDevice: Void {
get throws {
guard isSupportedDevice else {
throw SHLLMError.unsupportedDevice
}
}
}

private let directory: URL
private let context: ModelContext
private let configuration: ModelConfiguration
Expand Down
10 changes: 2 additions & 8 deletions Sources/SHLLM/Models/CodeLlama.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public actor CodeLlama: ModelProtocol {
public let llm: AsyncLockedValue<LLM>

public init(directory: URL) async throws {
try LLM.assertSupportedDevice
let llm = try await LLM.llama(directory: directory)
self.llm = .init(llm)
}
Expand All @@ -13,14 +14,7 @@ extension CodeLlama {
static var bundleDirectory: URL {
get throws {
let dir = "CodeLlama-13b-Instruct-hf-4bit-MLX"
guard let url = Bundle.shllm.url(
forResource: dir,
withExtension: nil,
subdirectory: "Resources"
) else {
throw SHLLMError.directoryNotFound(dir)
}
return url
return try Bundle.shllm.directory(named: dir)
}
}
}
10 changes: 2 additions & 8 deletions Sources/SHLLM/Models/DeepSeekR1.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public actor DeepSeekR1: ModelProtocol {
public let llm: AsyncLockedValue<LLM>

public init(directory: URL) async throws {
try LLM.assertSupportedDevice
let llm = try await LLM.qwen2(directory: directory)
self.llm = .init(llm)
}
Expand All @@ -13,14 +14,7 @@ extension DeepSeekR1 {
static var bundleDirectory: URL {
get throws {
let dir = "DeepSeek-R1-Distill-Qwen-7B-4bit"
guard let url = Bundle.shllm.url(
forResource: dir,
withExtension: nil,
subdirectory: "Resources"
) else {
throw SHLLMError.directoryNotFound(dir)
}
return url
return try Bundle.shllm.directory(named: dir)
}
}
}
10 changes: 2 additions & 8 deletions Sources/SHLLM/Models/Gemma.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public actor Gemma: ModelProtocol {
public let llm: AsyncLockedValue<LLM>

public init(directory: URL) async throws {
try LLM.assertSupportedDevice
let llm = try await LLM.gemma(directory: directory)
self.llm = .init(llm)
}
Expand All @@ -13,14 +14,7 @@ extension Gemma {
static var bundleDirectory: URL {
get throws {
let dir = "quantized-gemma-2b-it"
guard let url = Bundle.shllm.url(
forResource: dir,
withExtension: nil,
subdirectory: "Resources"
) else {
throw SHLLMError.directoryNotFound(dir)
}
return url
return try Bundle.shllm.directory(named: dir)
}
}
}
10 changes: 2 additions & 8 deletions Sources/SHLLM/Models/Gemma2-2B.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public actor Gemma2_2B: ModelProtocol {
public let llm: AsyncLockedValue<LLM>

public init(directory: URL) async throws {
try LLM.assertSupportedDevice
let llm = try await LLM.gemma2(directory: directory)
self.llm = .init(llm)
}
Expand All @@ -13,14 +14,7 @@ extension Gemma2_2B {
static var bundleDirectory: URL {
get throws {
let dir = "gemma-2-2b-it-4bit"
guard let url = Bundle.shllm.url(
forResource: dir,
withExtension: nil,
subdirectory: "Resources"
) else {
throw SHLLMError.directoryNotFound(dir)
}
return url
return try Bundle.shllm.directory(named: dir)
}
}
}
10 changes: 2 additions & 8 deletions Sources/SHLLM/Models/Gemma2-9B.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public actor Gemma2_9B: ModelProtocol {
public let llm: AsyncLockedValue<LLM>

public init(directory: URL) async throws {
try LLM.assertSupportedDevice
let llm = try await LLM.gemma2(directory: directory)
self.llm = .init(llm)
}
Expand All @@ -13,14 +14,7 @@ extension Gemma2_9B {
static var bundleDirectory: URL {
get throws {
let dir = "gemma-2-9b-it-4bit"
guard let url = Bundle.shllm.url(
forResource: dir,
withExtension: nil,
subdirectory: "Resources"
) else {
throw SHLLMError.directoryNotFound(dir)
}
return url
return try Bundle.shllm.directory(named: dir)
}
}
}
10 changes: 2 additions & 8 deletions Sources/SHLLM/Models/Llama3-8B.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public actor Llama3_8B: ModelProtocol {
public let llm: AsyncLockedValue<LLM>

public init(directory: URL) async throws {
try LLM.assertSupportedDevice
let llm = try await LLM.llama(directory: directory)
self.llm = .init(llm)
}
Expand All @@ -13,14 +14,7 @@ extension Llama3_8B {
static var bundleDirectory: URL {
get throws {
let dir = "Meta-Llama-3-8B-Instruct-4bit"
guard let url = Bundle.shllm.url(
forResource: dir,
withExtension: nil,
subdirectory: "Resources"
) else {
throw SHLLMError.directoryNotFound(dir)
}
return url
return try Bundle.shllm.directory(named: dir)
}
}
}
10 changes: 2 additions & 8 deletions Sources/SHLLM/Models/Llama3_1-8B.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public actor Llama3_1__8B: ModelProtocol {
public let llm: AsyncLockedValue<LLM>

public init(directory: URL) async throws {
try LLM.assertSupportedDevice
let llm = try await LLM.llama(directory: directory)
self.llm = .init(llm)
}
Expand All @@ -13,14 +14,7 @@ extension Llama3_1__8B {
static var bundleDirectory: URL {
get throws {
let dir = "Meta-Llama-3.1-8B-Instruct-4bit"
guard let url = Bundle.shllm.url(
forResource: dir,
withExtension: nil,
subdirectory: "Resources"
) else {
throw SHLLMError.directoryNotFound(dir)
}
return url
return try Bundle.shllm.directory(named: dir)
}
}
}
10 changes: 2 additions & 8 deletions Sources/SHLLM/Models/Llama3_2-1B.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public actor Llama3_2__1B: ModelProtocol {
public let llm: AsyncLockedValue<LLM>

public init(directory: URL) async throws {
try LLM.assertSupportedDevice
let llm = try await LLM.llama(directory: directory)
self.llm = .init(llm)
}
Expand All @@ -13,14 +14,7 @@ extension Llama3_2__1B {
static var bundleDirectory: URL {
get throws {
let dir = "Llama-3.2-1B-Instruct-4bit"
guard let url = Bundle.shllm.url(
forResource: dir,
withExtension: nil,
subdirectory: "Resources"
) else {
throw SHLLMError.directoryNotFound(dir)
}
return url
return try Bundle.shllm.directory(named: dir)
}
}
}
10 changes: 2 additions & 8 deletions Sources/SHLLM/Models/Llama3_2-3B.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public actor Llama3_2__3B: ModelProtocol {
public let llm: AsyncLockedValue<LLM>

public init(directory: URL) async throws {
try LLM.assertSupportedDevice
let llm = try await LLM.llama(directory: directory)
self.llm = .init(llm)
}
Expand All @@ -13,14 +14,7 @@ extension Llama3_2__3B {
static var bundleDirectory: URL {
get throws {
let dir = "Llama-3.2-3B-Instruct-4bit"
guard let url = Bundle.shllm.url(
forResource: dir,
withExtension: nil,
subdirectory: "Resources"
) else {
throw SHLLMError.directoryNotFound(dir)
}
return url
return try Bundle.shllm.directory(named: dir)
}
}
}
10 changes: 2 additions & 8 deletions Sources/SHLLM/Models/Mistral7B.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public actor Mistral7B: ModelProtocol {
public let llm: AsyncLockedValue<LLM>

public init(directory: URL) async throws {
try LLM.assertSupportedDevice
let llm = try await LLM.llama(directory: directory)
self.llm = .init(llm)
}
Expand All @@ -13,14 +14,7 @@ extension Mistral7B {
static var bundleDirectory: URL {
get throws {
let dir = "Mistral-7B-Instruct-v0.3-4bit"
guard let url = Bundle.shllm.url(
forResource: dir,
withExtension: nil,
subdirectory: "Resources"
) else {
throw SHLLMError.directoryNotFound(dir)
}
return url
return try Bundle.shllm.directory(named: dir)
}
}
}
Loading

0 comments on commit f3d9037

Please sign in to comment.