Skip to content

Commit 18c29ff

Browse files
author
黄宇扬
committedJul 23, 2024·
增加自定义模型文档
1 parent cb9ff7a commit 18c29ff

File tree

3 files changed

+324
-1
lines changed

3 files changed

+324
-1
lines changed
 

‎README.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ fastllm是纯c++实现,无第三方依赖的多平台高性能大模型推理
2020
- 🚀 支持动态Batch,流式输出
2121
- 🚀 前后端分离设计,便于支持新的计算设备
2222
- 🚀 目前支持ChatGLM系列模型,Qwen系列模型,各种LLAMA模型(ALPACA, VICUNA等),BAICHUAN模型,MOSS模型,MINICPM模型等
23+
- 🚀 支持Python自定义模型结构
2324

2425
## 快速开始
2526

@@ -69,12 +70,14 @@ python3 -m ftllm.webui -t 16 -p ~/Qwen2-7B-Instruct/ --port 8080
6970

7071
一些早期的HuggingFace模型无法直接读取,可以参考 [模型转换](docs/models.md#模型导出convert-offline) 转换fastllm格式的模型
7172

73+
可以自定义模型结构,具体见 [自定义模型](docs/custom_model.md)
74+
7275
### 运行demo程序 (c++)
7376

7477
```
7578
# 进入fastllm/build-fastllm目录
7679
77-
# 命令行聊天程序, 支持打字机效果 (只支持Linux)
80+
# 命令行聊天程序, 支持打字机效果
7881
./main -p ~/Qwen2-7B-Instruct/
7982
8083
# 简易webui, 使用流式输出 + 动态batch,可多路并发访问

‎docs/custom.md

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
对于Fastllm框架中没有支持的模型,可以通过自定义模型结构来支持
2+
3+
Pyhton 自定义模型只需要一个python文件来描述模型结构,可参考 [QWEN](../example/python/qwen2.py) 中的实现
4+
5+
### Python自定义模型的使用
6+
7+
使用ftllm.chat, ftllm.webui, ftllm.server时,可以加入参数--custom来指定自定义模型文件
8+
9+
假设我们的模型位于 "~/Qwen2-7B-Instruct/" 目录,自定义模型位于 "~/qwen2.py"
10+
11+
那么可以使用命令
12+
13+
``` sh
14+
python3 -m ftllm.chat -t 16 -p ~/Qwen2-7B-Instruct/ --custom ~/qwen2.py
15+
```
16+
17+
来通过自定义模型文件加在Qwen2模型,server和webui用法类似
18+
19+
### Python自定义模型的写法
20+
21+
自定义模型时,需要实现一个模型的描述类,继承自ftllm.llm.ComputeGraph
22+
23+
对应 [QWEN](../example/python/qwen2.py) 中的代码
24+
25+
``` python
26+
from ftllm.llm import ComputeGraph
27+
class Qwen2Model(ComputeGraph):
28+
```
29+
30+
文件最后需要定义 `__model__` 变量来指定自定义模型结构对应的class, 对应代码
31+
32+
``` python
33+
__model__ = Qwen2Model
34+
```
35+
36+
模型描述类中需要实现build方法,来获取模型参数、描述计算流程
37+
38+
这里以示例代码为例介绍
39+
40+
``` python
41+
class Qwen2Model(ComputeGraph):
42+
def build(self):
43+
# 1. 获取weight, data, config
44+
weight, data, config = self.weight, self.data, self.config
45+
46+
# 2. 设置一些config
47+
config["max_positions"] = 128000
48+
49+
# 3. 描述计算流程
50+
head_dim = config["hidden_size"] // config["num_attention_heads"]
51+
self.Embedding(data["inputIds"], weight["model.embed_tokens.weight"], data["hiddenStates"]);
52+
# 以下是计算流程,具体参见示例代码
53+
```
54+
55+
#### `self.config`
56+
57+
模型配置,默认会从模型文件夹下的 `config.json` 文件中读取
58+
59+
build方法中可以修改config中的参数,例如改动 `max_positions` 可以修改上下文长度
60+
61+
有一些模型的 `config.json` 中使用的变量名不一致,需要在build过程中手动为config赋值。
62+
63+
例如在TeleChat7B模型的配置中没有 `max_positions` 变量,而是用 `seq_length` 变量代表长度,那么在build方法中需要用如下代码赋值:
64+
65+
``` python
66+
self.config["max_positions"] = self.config["seq_length"]
67+
```
68+
69+
config中,有以下变量必须要赋值(如果config.json中变量名一致,可以不处理):
70+
71+
``` python
72+
self.config["max_positions"] #代表最长上下文长度
73+
```
74+
75+
#### `self.weight`
76+
77+
代表权重数据
78+
79+
`self.weight[weightName]` 代表模型文件中名为weightName的参数(对应HF模型文件夹中.safetensors文件中的参数名)
80+
81+
#### ```self.data```
82+
83+
代表计算流程的中间变量和输入变量
84+
85+
`self.data[dataName]` 代表名为dataName的中间变量,`dataName` 可以使用除以下输入变量名之外的任意字符串
86+
87+
输入变量:
88+
89+
``` python
90+
data["inputIds"] # 输入token
91+
data["positionIds"] # 位置信息
92+
data["attentionMask"] # mask信息
93+
data["sin"] # 用于旋转编码的sin
94+
data["cos"] # 用于旋转编码的cos
95+
data["atype"] # 推理中的数据类型
96+
data["pastKey."][i] # 第i个block的key cache
97+
data["pastValue."][i] # 第i个block的value cache
98+
```
99+
100+
#### 计算流程及算子
101+
102+
使用基类ComputeGraph添加算子的函数来描述计算流程
103+
104+
目前支持的算子见文档 [自定义模型算子](./custom_op.md)
105+
106+
### cpp版本的自定义模型
107+
108+
(cpp版本的自定义模型接口还在修改中...)

‎docs/custom_op.md

+212
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
## 自定义模型算子文档
2+
3+
### `AddTo`
4+
```python
5+
def AddTo(self, input0, input1, alpha = 1.0):
6+
"""
7+
将两个输入节点相加,并乘以一个可选的缩放因子 alpha。
8+
9+
参数:
10+
input0 (GraphNode): 第一个输入节点。
11+
input1 (GraphNode): 第二个输入节点。
12+
alpha (float, optional): 缩放因子,默认为 1.0。
13+
14+
返回:
15+
无返回值,结果存储在内部图结构中。
16+
"""
17+
self.graph.append({"type": "AddTo",
18+
"nodes": {"input0": input0, "input1": input1, "alpha": FloatGraphNode(alpha)}})
19+
```
20+
21+
### `DataTypeAs`
22+
```python
23+
def DataTypeAs(self, input, input1):
24+
"""
25+
将输入节点的数据类型转换为另一个输入节点的数据类型。
26+
27+
参数:
28+
input (GraphNode): 需要转换数据类型的输入节点。
29+
input1 (GraphNode): 目标数据类型的输入节点。
30+
31+
返回:
32+
无返回值,结果存储在内部图结构中。
33+
"""
34+
self.graph.append({"type": "DataTypeAs",
35+
"nodes": {"input": input, "input1": input1}})
36+
```
37+
38+
### `Embedding`
39+
```python
40+
def Embedding(self, input, weight, output):
41+
"""
42+
执行嵌入操作,将输入索引映射到嵌入权重。
43+
44+
参数:
45+
input (GraphNode): 输入索引节点。
46+
weight (GraphNode): 嵌入权重节点。
47+
output (GraphNode): 输出节点。
48+
49+
返回:
50+
无返回值,结果存储在内部图结构中。
51+
"""
52+
self.graph.append({"type": "Embedding",
53+
"nodes": {"input": input, "weight": weight, "output": output}})
54+
```
55+
56+
### `ExpandHead`
57+
```python
58+
def ExpandHead(self, input, headDim):
59+
"""
60+
把input最后一维展开成[-1, headDim]。
61+
62+
参数:
63+
input (GraphNode): 输入节点。
64+
headDim (int): 头部维度大小。
65+
66+
返回:
67+
无返回值,结果存储在内部图结构中。
68+
"""
69+
self.graph.append({"type": "ExpandHeads",
70+
"nodes": {"input": input, "headDim": IntGraphNode(headDim)}})
71+
```
72+
73+
### `FusedAttention`
74+
```python
75+
def FusedAttention(self, q, k, v, curk, curv, original, mask, output, seqLens,
76+
scale, maskType=0, unitLen=128):
77+
"""
78+
执行Attention操作。
79+
80+
参数:
81+
q (GraphNode): 查询节点。
82+
k (GraphNode): key cache
83+
v (GraphNode): value cache
84+
curk (GraphNode): 当前key
85+
curv (GraphNode): 当前value
86+
original (GraphNode): 原始节点,用于恢复计算后的shape
87+
mask (GraphNode): 掩码
88+
output (GraphNode): 输出
89+
seqLens (GraphNode): 序列长度
90+
scale (float): 缩放因子
91+
maskType (int, optional): 掩码类型,默认为 0。
92+
unitLen (int, optional): 单元长度,默认为 128。
93+
94+
返回:
95+
无返回值,结果存储在内部图结构中。
96+
"""
97+
self.graph.append({"type": "FusedAttention",
98+
"nodes": {"q": q, "k": k, "v": v, "curk": curk, "curv": curv,
99+
"original": original, "mask": mask, "output": output, "seqLens": seqLens,
100+
"scale": FloatGraphNode(scale),
101+
"maskType": IntGraphNode(maskType), "unitLen": IntGraphNode(unitLen)}})
102+
```
103+
104+
### `Linear`
105+
```python
106+
def Linear(self, input, weight, bias, output):
107+
"""
108+
执行线性变换操作。
109+
110+
参数:
111+
input (GraphNode): 输入节点。
112+
weight (GraphNode): 权重节点。
113+
bias (GraphNode): 偏置节点。
114+
output (GraphNode): 输出节点。
115+
116+
返回:
117+
无返回值,结果存储在内部图结构中。
118+
"""
119+
self.graph.append({"type": "Linear",
120+
"nodes": {"input": input, "weight": weight, "bias": bias, "output": output}})
121+
```
122+
123+
### `LlamaRotatePosition2D`
124+
```python
125+
def LlamaRotatePosition2D(self, input, positionIds, sin, cos, rotaryDim):
126+
"""
127+
执行 Llama 模型的二维位置旋转操作。
128+
129+
参数:
130+
input (GraphNode): 输入节点。
131+
positionIds (GraphNode): 位置 ID 节点。
132+
sin (GraphNode): 正弦节点。
133+
cos (GraphNode): 余弦节点。
134+
rotaryDim (int): 旋转维度大小。
135+
136+
返回:
137+
无返回值,结果存储在内部图结构中。
138+
"""
139+
self.graph.append({"type": "LlamaRotatePosition2D",
140+
"nodes": {"input": input, "positionIds": positionIds, "sin": sin, "cos": cos, "rotaryDim": IntGraphNode(rotaryDim)}})
141+
```
142+
143+
### `MulTo`
144+
```python
145+
def MulTo(self, input0, input1):
146+
"""
147+
将两个输入节点相乘。
148+
149+
参数:
150+
input0 (GraphNode): 第一个输入节点。
151+
input1 (GraphNode): 第二个输入节点。
152+
153+
返回:
154+
无返回值,结果存储在内部图结构中。
155+
"""
156+
self.graph.append({"type": "MulTo",
157+
"nodes": {"input0": input0, "input1": input1}})
158+
```
159+
160+
### `RMSNorm`
161+
```python
162+
def RMSNorm(self, input, weight, eps, output):
163+
"""
164+
执行 RMS 归一化操作。
165+
166+
参数:
167+
input (GraphNode): 输入节点。
168+
weight (GraphNode): 权重节点。
169+
eps (float): 小常数,用于防止除零错误。
170+
output (GraphNode): 输出节点。
171+
172+
返回:
173+
无返回值,结果存储在内部图结构中。
174+
"""
175+
self.graph.append({"type": "RMSNorm",
176+
"nodes": {"input": input, "weight": weight, "eps": FloatGraphNode(eps), "output": output}})
177+
```
178+
179+
### `Silu`
180+
```python
181+
def Silu(self, input, output):
182+
"""
183+
执行 SiLU(Sigmoid Linear Unit)激活函数操作。
184+
185+
参数:
186+
input (GraphNode): 输入节点。
187+
output (GraphNode): 输出节点。
188+
189+
返回:
190+
无返回值,结果存储在内部图结构中。
191+
"""
192+
self.graph.append({"type": "Silu",
193+
"nodes": {"input": input, "output": output}})
194+
```
195+
196+
### `SplitLastTokenStates`
197+
```python
198+
def SplitLastTokenStates(self, input, seqLens, output):
199+
"""
200+
分割batch输入中每个batch的最后一个 token 状态。
201+
202+
参数:
203+
input (GraphNode): 输入节点。
204+
seqLens (GraphNode): 序列长度节点。
205+
output (GraphNode): 输出节点。
206+
207+
返回:
208+
无返回值,结果存储在内部图结构中。
209+
"""
210+
self.graph.append({"type": "SplitLastTokenStates",
211+
"nodes": {"input": input, "output": output, "seqLens": seqLens}})
212+
```

0 commit comments

Comments
 (0)