-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmodel_server_normal.py
77 lines (63 loc) · 2.3 KB
/
model_server_normal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import asyncio
import websockets
from transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer
import torch
# Local directory holding the MiniCPM checkpoint; trust_remote_code is required
# because the model ships custom modeling code (chat / stream_generate /
# reset_chat_history are custom-API methods used below, not standard HF ones).
model_dir = "minicpm"
tokenizer = LlamaTokenizer.from_pretrained(model_dir, trust_remote_code=True)
# bfloat16 on GPU; .eval() disables dropout etc. for inference.
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
model = model.eval()
# Generation / sampling parameters passed to model.chat().
max_length = 4096
top_p = 0.8
temperature = 0.8
top_k = 0
# Number of tokens accumulated per websocket message when streaming the reply
# (batches small tokens to reduce per-send overhead).
int_out_token = 15
async def echo(websocket, path):
    """Serve one websocket client, streaming model chat replies.

    Protocol (plain-text messages from the client):
      "<!new_session!>"          -- reset chat history, duplex mode
      "<!new_ordinary_session!>" -- reset chat history, ordinary mode
      anything else              -- a user message; the model reply is streamed
                                   back in chunks and terminated by "<!end!>".

    ``path`` is required by the legacy ``websockets.serve`` handler signature
    but is unused here.
    """
    # NOTE(review): set but never read in this handler -- presumably consumed
    # by a fuller variant of this server; kept for parity.
    is_ordinary = False
    async for message in websocket:
        message = message.strip()
        if not message:
            continue
        # new duplex session
        if message == "<!new_session!>":
            print("New session started.")
            model.reset_chat_history()
            is_ordinary = False
            continue
        # new ordinary session
        if message == "<!new_ordinary_session!>":
            print("New ordinary session started.")
            model.reset_chat_history()
            is_ordinary = True
            continue
        # Feed the user message to the model. ret_code == 1 appears to signal
        # that the conversation hit the context-length limit (custom model API
        # -- confirm against the minicpm modeling code).
        ret_code = model.chat(tokenizer, message, max_length=max_length, top_p=top_p,
                              temperature=temperature, top_k=top_k)
        if ret_code == 1:
            print("You have reached the max length limit. Please start a new chat!")
            try:
                await websocket.send("<!too_long!>")
            except websockets.ConnectionClosed:
                return  # client went away; stop serving this connection
        # Stream the reply, batching up to int_out_token tokens per send.
        streaming = True
        while streaming:
            server_resp = ""
            for _ in range(int_out_token):
                response, _history = model.stream_generate()
                # Idle / end-of-sequence markers terminate the stream.
                if response is None or response in ["", "<idle>", "</s>"] or "<idle>" in response or "</s>" in response:
                    streaming = False
                    break
                server_resp += response
            if server_resp and server_resp not in ['<idle>', "</s>"]:
                try:
                    await websocket.send(server_resp)
                except websockets.ConnectionClosed:
                    # Client disconnected mid-stream: abandon generation instead
                    # of keeping the GPU busy sending into a dead connection.
                    return
                print(server_resp, end="", flush=True)
        try:
            await websocket.send("<!end!>")
        except websockets.ConnectionClosed:
            return
async def _serve_forever():
    """Run the websocket server on all interfaces, port 8766, until killed."""
    # host=None binds every available interface.
    async with websockets.serve(echo, None, 8766):
        print("Server started.")
        await asyncio.Future()  # never completes: keeps the server alive


if __name__ == "__main__":
    # asyncio.run replaces the deprecated get_event_loop()/run_until_complete/
    # run_forever pattern (get_event_loop is deprecated outside a running loop
    # since Python 3.10).
    asyncio.run(_serve_forever())