
Commit 2cd52f6

Support p-tuning-v2
1 parent dc0cdfb commit 2cd52f6

File tree: 4 files changed (+49, -5)

- configs/model_config.py (+3)
- models/chatglm_llm.py (+34, -2)
- ptuning-v2/readme.md (+5)
- webui.py (+7, -3)

configs/model_config.py (+3)

@@ -24,6 +24,9 @@
 # LLM model name
 LLM_MODEL = "chatglm-6b"
 
+# Use p-tuning-v2 PrefixEncoder
+USE_PTUNING_V2 = False
+
 # LLM running device
 LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 
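
The new `USE_PTUNING_V2` flag is only a default: webui.py (below) binds it to a checkbox and threads the value through `reinit_model` into `init_cfg`, which presumably forwards it to `load_model`. A minimal sketch of driving the loader directly, assuming models/chatglm_llm.py exposes a `ChatGLM` wrapper class (the class name is an assumption, not shown in this diff):

```python
# A minimal sketch, not part of this commit. Assumes models/chatglm_llm.py
# defines a ChatGLM wrapper whose load_model() is the method patched below.
from configs.model_config import USE_PTUNING_V2, LLM_DEVICE
from models.chatglm_llm import ChatGLM  # class name assumed

llm = ChatGLM()
# With use_ptuning_v2=True, load_model() reads ptuning-v2/config.json and
# ptuning-v2/pytorch_model.bin, as implemented in the next diff.
llm.load_model(model_name_or_path="THUDM/chatglm-6b",
               llm_device=LLM_DEVICE,
               use_ptuning_v2=USE_PTUNING_V2)
```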

models/chatglm_llm.py (+34, -2)

@@ -1,7 +1,10 @@
+import json
+import os
+
 from langchain.llms.base import LLM
 from typing import Optional, List
 from langchain.llms.utils import enforce_stop_tokens
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoConfig
 import torch
 from configs.model_config import LLM_DEVICE
 
@@ -51,15 +54,30 @@ def _call(self,
 
     def load_model(self,
                    model_name_or_path: str = "THUDM/chatglm-6b",
-                   llm_device=LLM_DEVICE):
+                   llm_device=LLM_DEVICE,
+                   use_ptuning_v2=False):
         self.tokenizer = AutoTokenizer.from_pretrained(
             model_name_or_path,
             trust_remote_code=True
         )
+
+        model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+
+        if use_ptuning_v2:
+            try:
+                prefix_encoder_file = open('ptuning-v2/config.json', 'r')
+                prefix_encoder_config = json.loads(prefix_encoder_file.read())
+                prefix_encoder_file.close()
+                model_config.pre_seq_len = prefix_encoder_config['pre_seq_len']
+                model_config.prefix_projection = prefix_encoder_config['prefix_projection']
+            except Exception:
+                print("Failed to load the PrefixEncoder config.json")
+
         if torch.cuda.is_available() and llm_device.lower().startswith("cuda"):
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
+                    config=model_config,
                     trust_remote_code=True)
                 .half()
                 .cuda()
@@ -68,8 +86,22 @@ def load_model(self,
             self.model = (
                 AutoModel.from_pretrained(
                     model_name_or_path,
+                    config=model_config,
                     trust_remote_code=True)
                 .float()
                 .to(llm_device)
             )
+
+        if use_ptuning_v2:
+            try:
+                prefix_state_dict = torch.load('ptuning-v2/pytorch_model.bin')
+                new_prefix_state_dict = {}
+                for k, v in prefix_state_dict.items():
+                    if k.startswith("transformer.prefix_encoder."):
+                        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
+                self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
+                self.model.transformer.prefix_encoder.float()
+            except Exception:
+                print("Failed to load the PrefixEncoder model parameters")
+
         self.model = self.model.eval()
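
The p-tuning-v2 checkpoint stores the PrefixEncoder weights under their full module path, so the loader strips the `transformer.prefix_encoder.` prefix before calling `load_state_dict` on the submodule; the trailing `.float()` keeps the PrefixEncoder in full precision even when the base model is halved for CUDA. A standalone illustration of the key-stripping step (the example keys and values are made up):

```python
# Standalone illustration of the key renaming performed above; the tensor
# values and the second key are hypothetical examples.
prefix = "transformer.prefix_encoder."
prefix_state_dict = {
    "transformer.prefix_encoder.embedding.weight": "<tensor>",
    "transformer.word_embeddings.weight": "<tensor>",  # skipped: wrong prefix
}
new_prefix_state_dict = {k[len(prefix):]: v
                         for k, v in prefix_state_dict.items()
                         if k.startswith(prefix)}
print(new_prefix_state_dict)  # {'embedding.weight': '<tensor>'}
```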

ptuning-v2/readme.md (+5)

@@ -0,0 +1,5 @@
+If you fine-tuned the model with [p-tuning-v2](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning), you can put the resulting PrefixEncoder in this folder.
+
+You only need to place the model's *config.json* and *pytorch_model.bin* here,
+
+then check *"Use a model fine-tuned with p-tuning-v2"* when loading the model.
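
For reference, `load_model()` reads exactly two keys, `pre_seq_len` and `prefix_projection`, from this folder's config.json. A quick sanity check that the folder is laid out as the loader expects (an illustrative sketch, not part of the commit):

```python
# Illustrative sanity check, not part of this commit.
import json
import os

assert os.path.isfile("ptuning-v2/pytorch_model.bin"), "PrefixEncoder weights missing"
with open("ptuning-v2/config.json", "r") as f:
    cfg = json.load(f)
# The only two keys load_model() copies onto the model config:
print("pre_seq_len:", cfg["pre_seq_len"])
print("prefix_projection:", cfg["prefix_projection"])
```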

webui.py (+7, -3)

@@ -53,11 +53,12 @@ def init_model():
         return """The model failed to load; please reselect it and click the "Load model" button"""
 
 
-def reinit_model(llm_model, embedding_model, llm_history_len, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, history):
     try:
         local_doc_qa.init_cfg(llm_model=llm_model,
                               embedding_model=embedding_model,
                               llm_history_len=llm_history_len,
+                              use_ptuning_v2=use_ptuning_v2,
                               top_k=top_k)
         model_status = """The model was reloaded successfully; please select a file and click the "Load file" button"""
     except:
@@ -97,7 +98,7 @@ def get_vector_store(filepath, history):
 """
 
 init_message = """Welcome to the langchain-ChatGLM Web UI. Before asking a question, please go through the following 3 steps in order:
-1. Select the language model, the Embedding model, and the related parameters, then click "Reload model" and wait for the loading-complete message
+1. Select the language model, the Embedding model, and the related parameters; if the model was fine-tuned with p-tuning-v2, put the PrefixEncoder model in the ptuning-v2 folder and check the corresponding option; then click "Reload model" and wait for the loading-complete message
 2. Upload or select existing files as the local knowledge documents, then click "Reload documents" and wait for the loading-complete message
 3. Enter the question to ask and press Enter to submit it """
 
@@ -127,6 +128,9 @@ def get_vector_store(filepath, history):
                                      step=1,
                                      label="LLM history len",
                                      interactive=True)
+        use_ptuning_v2 = gr.Checkbox(USE_PTUNING_V2,
+                                     label="Use a model fine-tuned with p-tuning-v2",
+                                     interactive=True)
         embedding_model = gr.Radio(embedding_model_dict_list,
                                    label="Embedding model",
                                    value=EMBEDDING_MODEL,
@@ -152,7 +156,7 @@ def get_vector_store(filepath, history):
         load_file_button = gr.Button("Load file")
         load_model_button.click(reinit_model,
                                 show_progress=True,
-                                inputs=[llm_model, embedding_model, llm_history_len, top_k, chatbot],
+                                inputs=[llm_model, embedding_model, llm_history_len, use_ptuning_v2, top_k, chatbot],
                                 outputs=chatbot
                                 )
         # Save the uploaded files to the content folder and update the dropdown
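
Gradio passes `inputs` positionally, so the position of `use_ptuning_v2` in `inputs=[...]` must match its position in `reinit_model`'s signature; this diff keeps the two in sync by inserting it between `llm_history_len` and `top_k` in both places. A trimmed, runnable sketch of that wiring (component values are placeholders and `reinit_model` is stubbed out):

```python
# Trimmed sketch of the wiring added above; assumes gradio is installed.
# All defaults are placeholders; the real reinit_model calls init_cfg().
import gradio as gr

def reinit_model(llm_model, embedding_model, llm_history_len,
                 use_ptuning_v2, top_k, history):
    # Stub: just echo the checkbox state into the chat history.
    return (history or []) + [[None, f"use_ptuning_v2={use_ptuning_v2}"]]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    llm_model = gr.Radio(["chatglm-6b"], value="chatglm-6b", label="LLM model")
    embedding_model = gr.Radio(["text2vec"], value="text2vec", label="Embedding model")
    llm_history_len = gr.Slider(0, 10, value=3, step=1, label="LLM history len")
    use_ptuning_v2 = gr.Checkbox(False, label="Use a model fine-tuned with p-tuning-v2")
    top_k = gr.Slider(1, 20, value=6, step=1, label="Top k")
    load_model_button = gr.Button("Reload model")
    load_model_button.click(reinit_model,
                            show_progress=True,
                            inputs=[llm_model, embedding_model, llm_history_len,
                                    use_ptuning_v2, top_k, chatbot],
                            outputs=chatbot)

demo.launch()
```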
