From 7d8cf838ea9391bbc74845e9880adfae2647de9c Mon Sep 17 00:00:00 2001
From: David Koski
Date: Wed, 26 Feb 2025 23:23:03 -0800
Subject: [PATCH] fix for #214 -- Qwen 2.5 fails to load: Key weight not found in Linear

- Qwen2 (LLM) had slightly incorrect logic in the initialization regarding
  lm_head - it was initialized even if not used, but this causes parameter
  loading to fail with current 0.21.3 mlx-swift
---
 Libraries/MLXLLM/Models/Qwen2.swift | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/Libraries/MLXLLM/Models/Qwen2.swift b/Libraries/MLXLLM/Models/Qwen2.swift
index 06dd2e9..567098a 100644
--- a/Libraries/MLXLLM/Models/Qwen2.swift
+++ b/Libraries/MLXLLM/Models/Qwen2.swift
@@ -172,22 +172,25 @@ public class Qwen2Model: Module, LLMModel, KVCacheDimensionProvider {
     private let model: Qwen2ModelInner
     let configuration: Qwen2Configuration
 
-    @ModuleInfo(key: "lm_head") var lmHead: Linear
+    @ModuleInfo(key: "lm_head") var lmHead: Linear?
 
     public init(_ args: Qwen2Configuration) {
         self.configuration = args
         self.vocabularySize = args.vocabularySize
         self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads }
         self.model = Qwen2ModelInner(args)
-        _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+
+        if !args.tieWordEmbeddings {
+            _lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false)
+        }
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
         var out = model(inputs, cache: cache)
-        if configuration.tieWordEmbeddings {
-            out = model.embedTokens.asLinear(out)
-        } else {
+        if let lmHead {
             out = lmHead(out)
+        } else {
+            out = model.embedTokens.asLinear(out)
         }
         return out
     }