CustomLLM.py
"""
Custom LLM wrapper for a locally hosted, OpenAI-compatible model.
"""
from typing import Any, List, Optional

import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class CustomLLM(LLM):
    endpoint: str = "***/v1"  # base URL of the local model's API
    model: str = "***"        # model name

    @property
    def _llm_type(self) -> str:
        return "***"  # identifier reported as this LLM's type

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Send the prompt as a single user message to the OpenAI-style
        # /chat/completions endpoint.
        headers = {"Content-Type": "application/json"}
        data = {
            "model": self.model,
            "messages": [{"role": "user", "content": prompt}],
        }
        response = requests.post(
            f"{self.endpoint}/chat/completions", headers=headers, json=data
        )
        if response.status_code != 200:
            return "error"
        # Extract the assistant's reply text from the first choice.
        result = response.json()
        text = result["choices"][0]["message"]["content"]
        return text
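

# A minimal usage sketch. Assumptions not in the original file: the endpoint
# and model values below ("http://localhost:8000/v1", "qwen") are placeholders,
# and the server exposes an OpenAI-compatible /chat/completions route.
# On LangChain versions with the Runnable interface, .invoke() sends the prompt;
# on older versions, llm("...") is the equivalent call.
if __name__ == "__main__":
    llm = CustomLLM(endpoint="http://localhost:8000/v1", model="qwen")
    print(llm.invoke("Hello! Briefly introduce yourself."))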