fix for #214 -- Qwen 2.5 fails to load: Key weight not found in Linear
- Qwen2 (LLM) had slightly incorrect initialization logic for lm_head
- lm_head was initialized even when it is not used (the tied-word-embeddings case), and with the current mlx-swift 0.21.3 that makes parameter loading fail, since such checkpoints carry no lm_head weights; see the sketch below
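As background (not part of the commit), here is a minimal, hypothetical sketch of the pattern the fix adopts. The type name TiedHeadModel and its initializer parameters are invented stand-ins for the real Qwen2 types; tieWordEmbeddings mirrors the configuration flag used in the diff. The idea: when word embeddings are tied, the checkpoint ships no "lm_head.weight", so the head is declared as an optional @ModuleInfo that is populated only for untied configurations, and the forward pass falls back to Embedding.asLinear.

    import MLX
    import MLXNN

    // Hypothetical, simplified model illustrating the fix: the projection head
    // is optional, so the parameter loader never expects "lm_head.weight" for
    // tied-embedding checkpoints.
    class TiedHeadModel: Module {
        @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding
        @ModuleInfo(key: "lm_head") var lmHead: Linear?  // nil when embeddings are tied

        init(vocabularySize: Int, hiddenSize: Int, tieWordEmbeddings: Bool) {
            self._embedTokens.wrappedValue = Embedding(
                embeddingCount: vocabularySize, dimensions: hiddenSize)
            if !tieWordEmbeddings {
                // Register the head only when the checkpoint actually contains one;
                // registering it unconditionally is what triggered
                // "Key weight not found in Linear" during loading.
                self._lmHead.wrappedValue = Linear(hiddenSize, vocabularySize, bias: false)
            }
            super.init()
        }

        func callAsFunction(_ hidden: MLXArray) -> MLXArray {
            if let lmHead {
                return lmHead(hidden)  // untied: dedicated output projection
            } else {
                // tied: project with the transposed embedding matrix
                return embedTokens.asLinear(hidden)
            }
        }
    }

Because lmHead stays nil for tied checkpoints, the loader never looks for an lm_head key, which is exactly what the diff below does for Qwen2Model.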
davidkoski committed Feb 27, 2025
1 parent b23bbf0 commit 7d8cf83
Showing 1 changed file with 8 additions and 5 deletions: Libraries/MLXLLM/Models/Qwen2.swift
@@ -172,22 +172,25 @@ public class Qwen2Model: Module, LLMModel, KVCacheDimensionProvider {
     private let model: Qwen2ModelInner
     let configuration: Qwen2Configuration
 
-    @ModuleInfo(key: "lm_head") var lmHead: Linear
+    @ModuleInfo(key: "lm_head") var lmHead: Linear?
 
     public init(_ args: Qwen2Configuration) {
         self.configuration = args
         self.vocabularySize = args.vocabularySize
         self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
         self.model = Qwen2ModelInner(args)
-        _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+
+        if !args.tieWordEmbeddings {
+            _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+        }
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
         var out = model(inputs, cache: cache)
-        if configuration.tieWordEmbeddings {
-            out = model.embedTokens.asLinear(out)
-        } else {
+        if let lmHead {
             out = lmHead(out)
+        } else {
+            out = model.embedTokens.asLinear(out)
         }
         return out
     }
