infer.py (forked from ssbuild/chatglm_finetuning)
# -*- coding: utf-8 -*-
# @Time : 2023/3/9 15:29
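# Inference script: loads a ChatGLM checkpoint through deep_training, quantizes it
# on the fly if needed, and runs a few sample chat prompts.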
from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments
from transformers import HfArgumentParser
from data_utils import train_info_args, NN_DataHelper
from models import MyTransformer, ChatGLMTokenizer, LoraArguments, setup_model_profile, ChatGLMConfig


if __name__ == '__main__':
    # Inference does not need the fixed training seed.
    train_info_args['seed'] = None

    parser = HfArgumentParser((ModelArguments, DataArguments,))
    model_args, data_args = parser.parse_dict(train_info_args, allow_extra_keys=True)
    setup_model_profile()

    dataHelper = NN_DataHelper(model_args, None, data_args)
    tokenizer: ChatGLMTokenizer
    tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(
        tokenizer_class_name=ChatGLMTokenizer, config_class_name=ChatGLMConfig)
    # Sanity check: ChatGLM's vocabulary places the eos token at id 130005.
    assert tokenizer.eos_token_id == 130005

    # Weights come from the checkpoint, so skip weight re-initialization.
    config.initializer_weight = False
    pl_model = MyTransformer(config=config, model_args=model_args)
    model = pl_model.get_llm_model()
    if not model.quantized:
        # Modify as needed. Only 4/8-bit quantization is supported for now,
        # and the quantized model can be saved.
        model.half().quantize(4).cuda()
    else:
        # Already quantized.
        model.half().cuda()
    model = model.eval()
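    # A hedged sketch: the comment above notes that the quantized weights can be
    # saved. Assuming the model exposes the standard transformers
    # save_pretrained API, one plausible way would be (path is illustrative):
    # model.save_pretrained('./chatglm-int4')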
    text_list = [
        "写一个诗歌,关于冬天",  # "Write a poem about winter"
        "晚上睡不着应该怎么办",  # "What should I do if I can't sleep at night?"
    ]
    for query in text_list:  # `query` avoids shadowing the builtin `input`
        response, history = model.chat(tokenizer, query, history=[], max_length=2048,
                                       eos_token_id=config.eos_token_id,
                                       do_sample=True, top_p=0.7, temperature=0.95)
        print("input:", query)
        print("response:", response)
    # Example of a shorter, length-capped generation:
    # response, history = model.chat(tokenizer, "写一个诗歌,关于冬天", history=[], max_length=30)
    # print("写一个诗歌,关于冬天", response)
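
# A minimal usage sketch, assuming train_info_args in data_utils points at a valid
# ChatGLM checkpoint and a CUDA device is available:
#   python infer.py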