|
| 1 | +## 自定义模型算子文档 |
| 2 | + |
| 3 | +### `AddTo` |
| 4 | +```python |
| 5 | +def AddTo(self, input0, input1, alpha = 1.0): |
| 6 | + """ |
| 7 | + 将两个输入节点相加,并乘以一个可选的缩放因子 alpha。 |
| 8 | +
|
| 9 | + 参数: |
| 10 | + input0 (GraphNode): 第一个输入节点。 |
| 11 | + input1 (GraphNode): 第二个输入节点。 |
| 12 | + alpha (float, optional): 缩放因子,默认为 1.0。 |
| 13 | +
|
| 14 | + 返回: |
| 15 | + 无返回值,结果存储在内部图结构中。 |
| 16 | + """ |
| 17 | + self.graph.append({"type": "AddTo", |
| 18 | + "nodes": {"input0": input0, "input1": input1, "alpha": FloatGraphNode(alpha)}}) |
| 19 | +``` |
| 20 | + |
| 21 | +### `DataTypeAs` |
| 22 | +```python |
| 23 | +def DataTypeAs(self, input, input1): |
| 24 | + """ |
| 25 | + 将输入节点的数据类型转换为另一个输入节点的数据类型。 |
| 26 | +
|
| 27 | + 参数: |
| 28 | + input (GraphNode): 需要转换数据类型的输入节点。 |
| 29 | + input1 (GraphNode): 目标数据类型的输入节点。 |
| 30 | +
|
| 31 | + 返回: |
| 32 | + 无返回值,结果存储在内部图结构中。 |
| 33 | + """ |
| 34 | + self.graph.append({"type": "DataTypeAs", |
| 35 | + "nodes": {"input": input, "input1": input1}}) |
| 36 | +``` |
| 37 | + |
| 38 | +### `Embedding` |
| 39 | +```python |
| 40 | +def Embedding(self, input, weight, output): |
| 41 | + """ |
| 42 | + 执行嵌入操作,将输入索引映射到嵌入权重。 |
| 43 | +
|
| 44 | + 参数: |
| 45 | + input (GraphNode): 输入索引节点。 |
| 46 | + weight (GraphNode): 嵌入权重节点。 |
| 47 | + output (GraphNode): 输出节点。 |
| 48 | +
|
| 49 | + 返回: |
| 50 | + 无返回值,结果存储在内部图结构中。 |
| 51 | + """ |
| 52 | + self.graph.append({"type": "Embedding", |
| 53 | + "nodes": {"input": input, "weight": weight, "output": output}}) |
| 54 | +``` |
| 55 | + |
| 56 | +### `ExpandHead` |
| 57 | +```python |
| 58 | +def ExpandHead(self, input, headDim): |
| 59 | + """ |
| 60 | + 把input最后一维展开成[-1, headDim]。 |
| 61 | +
|
| 62 | + 参数: |
| 63 | + input (GraphNode): 输入节点。 |
| 64 | + headDim (int): 头部维度大小。 |
| 65 | +
|
| 66 | + 返回: |
| 67 | + 无返回值,结果存储在内部图结构中。 |
| 68 | + """ |
| 69 | + self.graph.append({"type": "ExpandHeads", |
| 70 | + "nodes": {"input": input, "headDim": IntGraphNode(headDim)}}) |
| 71 | +``` |
| 72 | + |
| 73 | +### `FusedAttention` |
| 74 | +```python |
| 75 | +def FusedAttention(self, q, k, v, curk, curv, original, mask, output, seqLens, |
| 76 | + scale, maskType=0, unitLen=128): |
| 77 | + """ |
| 78 | + 执行Attention操作。 |
| 79 | +
|
| 80 | + 参数: |
| 81 | + q (GraphNode): 查询节点。 |
| 82 | + k (GraphNode): key cache |
| 83 | + v (GraphNode): value cache |
| 84 | + curk (GraphNode): 当前key |
| 85 | + curv (GraphNode): 当前value |
| 86 | + original (GraphNode): 原始节点,用于恢复计算后的shape |
| 87 | + mask (GraphNode): 掩码 |
| 88 | + output (GraphNode): 输出 |
| 89 | + seqLens (GraphNode): 序列长度 |
| 90 | + scale (float): 缩放因子 |
| 91 | + maskType (int, optional): 掩码类型,默认为 0。 |
| 92 | + unitLen (int, optional): 单元长度,默认为 128。 |
| 93 | +
|
| 94 | + 返回: |
| 95 | + 无返回值,结果存储在内部图结构中。 |
| 96 | + """ |
| 97 | + self.graph.append({"type": "FusedAttention", |
| 98 | + "nodes": {"q": q, "k": k, "v": v, "curk": curk, "curv": curv, |
| 99 | + "original": original, "mask": mask, "output": output, "seqLens": seqLens, |
| 100 | + "scale": FloatGraphNode(scale), |
| 101 | + "maskType": IntGraphNode(maskType), "unitLen": IntGraphNode(unitLen)}}) |
| 102 | +``` |
| 103 | + |
| 104 | +### `Linear` |
| 105 | +```python |
| 106 | +def Linear(self, input, weight, bias, output): |
| 107 | + """ |
| 108 | + 执行线性变换操作。 |
| 109 | +
|
| 110 | + 参数: |
| 111 | + input (GraphNode): 输入节点。 |
| 112 | + weight (GraphNode): 权重节点。 |
| 113 | + bias (GraphNode): 偏置节点。 |
| 114 | + output (GraphNode): 输出节点。 |
| 115 | +
|
| 116 | + 返回: |
| 117 | + 无返回值,结果存储在内部图结构中。 |
| 118 | + """ |
| 119 | + self.graph.append({"type": "Linear", |
| 120 | + "nodes": {"input": input, "weight": weight, "bias": bias, "output": output}}) |
| 121 | +``` |
| 122 | + |
| 123 | +### `LlamaRotatePosition2D` |
| 124 | +```python |
| 125 | +def LlamaRotatePosition2D(self, input, positionIds, sin, cos, rotaryDim): |
| 126 | + """ |
| 127 | + 执行 Llama 模型的二维位置旋转操作。 |
| 128 | +
|
| 129 | + 参数: |
| 130 | + input (GraphNode): 输入节点。 |
| 131 | + positionIds (GraphNode): 位置 ID 节点。 |
| 132 | + sin (GraphNode): 正弦节点。 |
| 133 | + cos (GraphNode): 余弦节点。 |
| 134 | + rotaryDim (int): 旋转维度大小。 |
| 135 | +
|
| 136 | + 返回: |
| 137 | + 无返回值,结果存储在内部图结构中。 |
| 138 | + """ |
| 139 | + self.graph.append({"type": "LlamaRotatePosition2D", |
| 140 | + "nodes": {"input": input, "positionIds": positionIds, "sin": sin, "cos": cos, "rotaryDim": IntGraphNode(rotaryDim)}}) |
| 141 | +``` |
| 142 | + |
| 143 | +### `MulTo` |
| 144 | +```python |
| 145 | +def MulTo(self, input0, input1): |
| 146 | + """ |
| 147 | + 将两个输入节点相乘。 |
| 148 | +
|
| 149 | + 参数: |
| 150 | + input0 (GraphNode): 第一个输入节点。 |
| 151 | + input1 (GraphNode): 第二个输入节点。 |
| 152 | +
|
| 153 | + 返回: |
| 154 | + 无返回值,结果存储在内部图结构中。 |
| 155 | + """ |
| 156 | + self.graph.append({"type": "MulTo", |
| 157 | + "nodes": {"input0": input0, "input1": input1}}) |
| 158 | +``` |
| 159 | + |
| 160 | +### `RMSNorm` |
| 161 | +```python |
| 162 | +def RMSNorm(self, input, weight, eps, output): |
| 163 | + """ |
| 164 | + 执行 RMS 归一化操作。 |
| 165 | +
|
| 166 | + 参数: |
| 167 | + input (GraphNode): 输入节点。 |
| 168 | + weight (GraphNode): 权重节点。 |
| 169 | + eps (float): 小常数,用于防止除零错误。 |
| 170 | + output (GraphNode): 输出节点。 |
| 171 | +
|
| 172 | + 返回: |
| 173 | + 无返回值,结果存储在内部图结构中。 |
| 174 | + """ |
| 175 | + self.graph.append({"type": "RMSNorm", |
| 176 | + "nodes": {"input": input, "weight": weight, "eps": FloatGraphNode(eps), "output": output}}) |
| 177 | +``` |
| 178 | + |
| 179 | +### `Silu` |
| 180 | +```python |
| 181 | +def Silu(self, input, output): |
| 182 | + """ |
| 183 | + 执行 SiLU(Sigmoid Linear Unit)激活函数操作。 |
| 184 | +
|
| 185 | + 参数: |
| 186 | + input (GraphNode): 输入节点。 |
| 187 | + output (GraphNode): 输出节点。 |
| 188 | +
|
| 189 | + 返回: |
| 190 | + 无返回值,结果存储在内部图结构中。 |
| 191 | + """ |
| 192 | + self.graph.append({"type": "Silu", |
| 193 | + "nodes": {"input": input, "output": output}}) |
| 194 | +``` |
| 195 | + |
| 196 | +### `SplitLastTokenStates` |
| 197 | +```python |
| 198 | +def SplitLastTokenStates(self, input, seqLens, output): |
| 199 | + """ |
| 200 | + 分割batch输入中每个batch的最后一个 token 状态。 |
| 201 | +
|
| 202 | + 参数: |
| 203 | + input (GraphNode): 输入节点。 |
| 204 | + seqLens (GraphNode): 序列长度节点。 |
| 205 | + output (GraphNode): 输出节点。 |
| 206 | +
|
| 207 | + 返回: |
| 208 | + 无返回值,结果存储在内部图结构中。 |
| 209 | + """ |
| 210 | + self.graph.append({"type": "SplitLastTokenStates", |
| 211 | + "nodes": {"input": input, "output": output, "seqLens": seqLens}}) |
| 212 | +``` |
0 commit comments