Commit cb9ff7a

Author: 黄宇扬

Python custom models: support importing the model definition from a file
1 parent 506fccf commit cb9ff7a

File tree

5 files changed: +91 −42 lines

example/python/custom_model.py

+1 −36
@@ -1,45 +1,10 @@
 from ftllm import llm
-from ftllm.llm import ComputeGraph
+from qwen2 import Qwen2Model
 import os
-import math
 
 root_path = "/mnt/hfmodels/"
 model_path = os.path.join(root_path, "Qwen/Qwen2-7B-Instruct")
 
-class Qwen2Model(ComputeGraph):
-    def build(self):
-        weight, data, config = self.weight, self.data, self.config
-        head_dim = config["hidden_size"] // config["num_attention_heads"]
-        self.Embedding(data["inputIds"], weight["model.embed_tokens.weight"], data["hiddenStates"]);
-        self.DataTypeAs(data["hiddenStates"], data["atype"])
-        for i in range(config["num_hidden_layers"]):
-            pastKey = data["pastKey."][i]
-            pastValue = data["pastValue."][i]
-            layer = weight["model.layers."][i]
-            self.RMSNorm(data["hiddenStates"], layer[".input_layernorm.weight"], config["rms_norm_eps"], data["attenInput"])
-            self.Linear(data["attenInput"], layer[".self_attn.q_proj.weight"], layer[".self_attn.q_proj.bias"], data["q"])
-            self.Linear(data["attenInput"], layer[".self_attn.k_proj.weight"], layer[".self_attn.k_proj.bias"], data["k"])
-            self.Linear(data["attenInput"], layer[".self_attn.v_proj.weight"], layer[".self_attn.v_proj.bias"], data["v"])
-            self.ExpandHead(data["q"], head_dim)
-            self.ExpandHead(data["k"], head_dim)
-            self.ExpandHead(data["v"], head_dim)
-            self.LlamaRotatePosition2D(data["q"], data["positionIds"], data["sin"], data["cos"], head_dim // 2)
-            self.LlamaRotatePosition2D(data["k"], data["positionIds"], data["sin"], data["cos"], head_dim // 2)
-            self.FusedAttention(data["q"], pastKey, pastValue, data["k"], data["v"], data["attenInput"],
-                                data["attentionMask"], data["attenOutput"], data["seqLens"], 1.0 / math.sqrt(head_dim))
-            self.Linear(data["attenOutput"], layer[".self_attn.o_proj.weight"], layer[".self_attn.o_proj.bias"], data["attenLastOutput"]);
-            self.AddTo(data["hiddenStates"], data["attenLastOutput"]);
-            self.RMSNorm(data["hiddenStates"], layer[".post_attention_layernorm.weight"], config["rms_norm_eps"], data["attenInput"])
-            self.Linear(data["attenInput"], layer[".mlp.gate_proj.weight"], layer[".mlp.gate_proj.bias"], data["w1"])
-            self.Linear(data["attenInput"], layer[".mlp.up_proj.weight"], layer[".mlp.up_proj.bias"], data["w3"])
-            self.Silu(data["w1"], data["w1"])
-            self.MulTo(data["w1"], data["w3"])
-            self.Linear(data["w1"], layer[".mlp.down_proj.weight"], layer[".mlp.down_proj.bias"], data["w2"])
-            self.AddTo(data["hiddenStates"], data["w2"])
-        self.SplitLastTokenStates(data["hiddenStates"], data["seqLens"], data["lastTokensStates"])
-        self.RMSNorm(data["lastTokensStates"], weight["model.norm.weight"], config["rms_norm_eps"], data["lastTokensStates"])
-        self.Linear(data["lastTokensStates"], weight["lm_head.weight"], weight["lm_head.bias"], data["logits"])
-
 model = llm.model(model_path, graph = Qwen2Model)
 prompt = "北京有什么景点?"
 messages = [
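
With the graph definition moved into its own module, the example script shrinks to a thin wrapper. A reconstruction of the resulting file from the context lines above (the trailing messages list is truncated in this diff view, so it is omitted here):

# example/python/custom_model.py after this commit (reconstruction)
from ftllm import llm
from qwen2 import Qwen2Model
import os

root_path = "/mnt/hfmodels/"
model_path = os.path.join(root_path, "Qwen/Qwen2-7B-Instruct")

# the imported ComputeGraph subclass is handed straight to llm.model
model = llm.model(model_path, graph = Qwen2Model)
prompt = "北京有什么景点?"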

example/python/qwen2.py

+40 −0
@@ -0,0 +1,40 @@
+from ftllm.llm import ComputeGraph
+import math
+
+class Qwen2Model(ComputeGraph):
+    def build(self):
+        weight, data, config = self.weight, self.data, self.config
+        config["max_positions"] = 128000
+
+        head_dim = config["hidden_size"] // config["num_attention_heads"]
+        self.Embedding(data["inputIds"], weight["model.embed_tokens.weight"], data["hiddenStates"]);
+        self.DataTypeAs(data["hiddenStates"], data["atype"])
+        for i in range(config["num_hidden_layers"]):
+            pastKey = data["pastKey."][i]
+            pastValue = data["pastValue."][i]
+            layer = weight["model.layers."][i]
+            self.RMSNorm(data["hiddenStates"], layer[".input_layernorm.weight"], config["rms_norm_eps"], data["attenInput"])
+            self.Linear(data["attenInput"], layer[".self_attn.q_proj.weight"], layer[".self_attn.q_proj.bias"], data["q"])
+            self.Linear(data["attenInput"], layer[".self_attn.k_proj.weight"], layer[".self_attn.k_proj.bias"], data["k"])
+            self.Linear(data["attenInput"], layer[".self_attn.v_proj.weight"], layer[".self_attn.v_proj.bias"], data["v"])
+            self.ExpandHead(data["q"], head_dim)
+            self.ExpandHead(data["k"], head_dim)
+            self.ExpandHead(data["v"], head_dim)
+            self.LlamaRotatePosition2D(data["q"], data["positionIds"], data["sin"], data["cos"], head_dim // 2)
+            self.LlamaRotatePosition2D(data["k"], data["positionIds"], data["sin"], data["cos"], head_dim // 2)
+            self.FusedAttention(data["q"], pastKey, pastValue, data["k"], data["v"], data["attenInput"],
+                                data["attentionMask"], data["attenOutput"], data["seqLens"], 1.0 / math.sqrt(head_dim))
+            self.Linear(data["attenOutput"], layer[".self_attn.o_proj.weight"], layer[".self_attn.o_proj.bias"], data["attenLastOutput"]);
+            self.AddTo(data["hiddenStates"], data["attenLastOutput"]);
+            self.RMSNorm(data["hiddenStates"], layer[".post_attention_layernorm.weight"], config["rms_norm_eps"], data["attenInput"])
+            self.Linear(data["attenInput"], layer[".mlp.gate_proj.weight"], layer[".mlp.gate_proj.bias"], data["w1"])
+            self.Linear(data["attenInput"], layer[".mlp.up_proj.weight"], layer[".mlp.up_proj.bias"], data["w3"])
+            self.Silu(data["w1"], data["w1"])
+            self.MulTo(data["w1"], data["w3"])
+            self.Linear(data["w1"], layer[".mlp.down_proj.weight"], layer[".mlp.down_proj.bias"], data["w2"])
+            self.AddTo(data["hiddenStates"], data["w2"])
+        self.SplitLastTokenStates(data["hiddenStates"], data["seqLens"], data["lastTokensStates"])
+        self.RMSNorm(data["lastTokensStates"], weight["model.norm.weight"], config["rms_norm_eps"], data["lastTokensStates"])
+        self.Linear(data["lastTokensStates"], weight["lm_head.weight"], weight["lm_head.bias"], data["logits"])
+
+__model__ = Qwen2Model
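
The file ends by exporting the graph class through a module-level __model__ attribute, which is the hook the new --custom loader in util.py (below) looks for. A minimal sketch of another loadable model file, using only operators that already appear in the graph above; the file name and the toy graph are hypothetical, chosen to show the required structure rather than a useful model:

# my_model.py (hypothetical): smallest file the --custom loader accepts
from ftllm.llm import ComputeGraph

class MyModel(ComputeGraph):
    def build(self):
        weight, data, config = self.weight, self.data, self.config
        # toy graph: embed the input tokens, then project the last token
        # of each sequence directly to logits
        self.Embedding(data["inputIds"], weight["model.embed_tokens.weight"], data["hiddenStates"])
        self.DataTypeAs(data["hiddenStates"], data["atype"])
        self.SplitLastTokenStates(data["hiddenStates"], data["seqLens"], data["lastTokensStates"])
        self.Linear(data["lastTokensStates"], weight["lm_head.weight"], weight["lm_head.bias"], data["logits"])

# the loader imports this file and reads the attribute below
__model__ = MyModel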

src/models/graph/fastllmjson.cpp

+30 −4
@@ -3,21 +3,47 @@
 namespace fastllm {
     class FastllmJsonModelConfig : GraphLLMModelConfig {
     public:
-        json11::Json config;
+        json11::Json json, graphJson, configJson, tokenizerConfigJson, generationConfigJson;
 
         void Init(const std::string &configString) {
             std::string error;
-            config = json11::Json::parse(configString, error);
+            json = json11::Json::parse(configString, error);
+            graphJson = json["graph"];
+            configJson = json["config"];
+            tokenizerConfigJson = json["tokenizer_config"];
+            generationConfigJson = json["generation_config"];
         }
 
         void InitParams(GraphLLMModel *model) {
+            if (configJson["max_positions"].is_number()) {
+                model->max_positions = configJson["max_positions"].int_value();
+            }
+            if (configJson["rope_base"].is_number()) {
+                model->rope_base = configJson["rope_base"].number_value();
+            }
+            if (configJson["rope_factor"].is_number()) {
+                model->rope_factor = configJson["rope_factor"].number_value();
+            }
+
+            if (configJson["pre_prompt"].is_string()) {
+                model->pre_prompt = configJson["pre_prompt"].string_value();
+            }
+            if (configJson["user_role"].is_string()) {
+                model->user_role = configJson["user_role"].string_value();
+            }
+            if (configJson["bot_role"].is_string()) {
+                model->bot_role = configJson["bot_role"].string_value();
+            }
+            if (configJson["history_sep"].is_string()) {
+                model->history_sep = configJson["history_sep"].string_value();
+            }
         }
 
         std::map <std::string, std::vector <std::pair <std::string, DataType> > >
                 GetTensorMap(GraphLLMModel *model, const std::vector <std::string> &tensorNames) {
             std::string embeddingName = "";
             std::map <std::string, std::vector <std::pair <std::string, DataType> > > ret;
-            for (auto &op : config.array_items()) {
+            for (auto &op : graphJson.array_items()) {
                 std::string type = op["type"].string_value();
                 std::map <std::string, std::string> weights;
                 for (auto &it : op["nodes"].object_items()) {
@@ -54,7 +80,7 @@ namespace fastllm {
                 wNodes[it.first] = ComputeGraphNode(it.first);
             }
 
-            for (auto &op : config.array_items()) {
+            for (auto &op : graphJson.array_items()) {
                 std::string type = op["type"].string_value();
                 std::map <std::string, std::string> datas;
                 std::map <std::string, float> floatParams;
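
Init now splits a single JSON document into four sections instead of treating the whole payload as the graph. The top-level shape, reconstructed from the keys read above and sketched as a Python dict; only the four top-level keys and the field names tested in InitParams come from the commit, the nested values are illustrative:

# shape of the JSON string handed to FastllmJsonModelConfig::Init
payload = {
    "graph": [
        {"type": "Embedding", "nodes": {}},   # one object per operator
    ],
    "config": {
        "max_positions": 128000,   # numeric fields read in InitParams
        "rope_base": 10000.0,
        "rope_factor": 1.0,
        "pre_prompt": "",          # string fields configure the chat template
        "user_role": "",
        "bot_role": "",
        "history_sep": "",
    },
    "tokenizer_config": {},
    "generation_config": {},
}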

tools/fastllm_pytools/llm.py

+8 −1
@@ -1020,7 +1020,14 @@ def __init__ (self):
         self.graph = []
 
     def __str__(self):
-        return json.dumps(self.graph, indent = 4, default = lambda x: x.to_json())
+        output = {"graph": self.graph}
+        if (hasattr(self, "config")):
+            output["config"] = self.config
+        if (hasattr(self, "tokenizer_config")):
+            output["tokenizer_config"] = self.tokenizer_config
+        if (hasattr(self, "generation_config")):
+            output["generation_config"] = self.generation_config
+        return json.dumps(output, indent = 4, default = lambda x: x.to_json())
 
     def Print(self, input):
         self.graph.append({"type": "Print",
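
__str__ is what serializes a ComputeGraph into the JSON string the C++ loader parses, so the optional sections appear only when the corresponding attributes were set on the graph object. A minimal sketch, assuming the base class can be instantiated directly (its __init__ above only creates the empty graph list):

# sketch: a config attribute, as set in qwen2.py, becomes a top-level
# "config" section in the serialized JSON
from ftllm.llm import ComputeGraph

g = ComputeGraph()
g.config = {"max_positions": 128000}
print(str(g))   # pretty-printed {"graph": [], "config": {"max_positions": 128000}}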

tools/fastllm_pytools/util.py

+12 −1
@@ -12,6 +12,7 @@ def make_normal_parser(des: str) -> argparse.ArgumentParser:
     parser.add_argument('--kv_cache_limit', type = str, default = "auto", help = 'kv缓存最大使用量')
     parser.add_argument('--max_batch', type = int, default = -1, help = '每次最多同时推理的询问数量')
     parser.add_argument('--device', type = str, help = '使用的设备')
+    parser.add_argument('--custom', type = str, default = "", help = '指定描述自定义模型的python文件')
     return parser
 
 def make_normal_llm_model(args):
@@ -29,7 +30,17 @@ def make_normal_llm_model(args):
     llm.set_cpu_low_mem(args.low)
     if (args.cuda_embedding):
         llm.set_cuda_embedding(True)
-    model = llm.model(args.path, dtype = args.dtype, tokenizer_type = "auto")
+    graph = None
+    if (args.custom != ""):
+        import importlib.util
+        spec = importlib.util.spec_from_file_location("custom_module", args.custom)
+        if spec is None:
+            raise ImportError(f"Cannot load module at {args.custom}")
+        custom_module = importlib.util.module_from_spec(spec)
+        spec.loader.exec_module(custom_module)
+        if (hasattr(custom_module, "__model__")):
+            graph = getattr(custom_module, "__model__")
+    model = llm.model(args.path, dtype = args.dtype, graph = graph, tokenizer_type = "auto")
     model.set_atype(args.atype)
     if (args.max_batch > 0):
         model.set_max_batch(args.max_batch)
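
The import logic is the standard importlib.util recipe for loading a module from an arbitrary file path. The same pattern as a self-contained helper:

# standalone version of the loading pattern used above
import importlib.util

def load_custom_graph(path: str):
    # import the Python file at `path` and return its __model__ attribute,
    # or None when the file does not export one
    spec = importlib.util.spec_from_file_location("custom_module", path)
    if spec is None:
        raise ImportError(f"Cannot load module at {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, "__model__", None)

# e.g. graph = load_custom_graph("example/python/qwen2.py")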
