-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathllm.py
133 lines (101 loc) · 4.75 KB
/
llm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import textwrap
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum

from langchain.schema import AIMessage, HumanMessage, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
@dataclass
class ModelOutput:
    """Result container returned by every model's ``predict`` call."""

    # Text produced by the model: an SQL command, a clarification
    # request, or an error string such as "Schema not loaded".
    message: str
    # True only when the message is a confirmed, final answer
    # (e.g. it survived the optional review pass); False otherwise.
    is_final_output: bool = False
class IBaseClass(ABC):
    """Interface every model wrapper in this module must implement."""

    @abstractmethod
    def predict(self, user_input: str) -> ModelOutput:
        """Answer *user_input* and report whether the answer is final."""
        ...
# Enumerators: the output formats this package can target.
class OutputTypes(Enum):
    SQL = 1
    # API = 2  # This is for a later use case
class NLP2SQL(IBaseClass):
    """
    This class implements the Natural Language to the SQL
    functionality of the package.

    The instance keeps a rolling chat history, bounded by the optional
    ``memory`` entry of ``options`` (measured in question/answer
    exchanges). When ``options['review']`` is truthy, a second LLM pass
    extracts the bare SQL from the first response and decides whether
    the answer is final.
    """
    system_prompt: str          # prompt template; must contain '{schema}'
    schema: str                 # database schema text; empty until loaded
    options: dict               # recognised keys: 'memory' (int), 'review' (bool)
    llm: BaseLanguageModel      # LangChain chat-capable model

    def __init__(self, llm, options) -> None:
        """
        :param llm: LangChain model used for prediction and review.
        :param options: behaviour switches; see class docstring.
        """
        self.llm = llm
        self.options = options
        self.schema = ''
        # textwrap.dedent strips only the source-code indentation.
        # (The original used str.replace-based dedenting, which could
        # also delete spaces *inside* the prompt text and corrupt it.)
        self.system_prompt = textwrap.dedent("""
            You are PostGreSQL Expert. You have to respond with PostGreSQL commands for the QUESTION asked by the user based on the DATABASE SCHEMA.
            Make sure to follow the 'IMPORTANT NOTE' and 'GUIDELINES' provided. ALWAYS REMEMBER TO FOLLOW 'IMPORTANT NOTE' & 'GUIDELINES', DO NOT DEVIATE FROM IT.
            If you can't answer the question or if the question is irrelevant to the Database Schema, Say 'I don't know'. Don't respond anything else.
            If a question is incomplete, ambiguous, unclear, or unrelated to the provided database, ASK FOR CLARIFICATION INSTEAD OF RESPONDING WITH PostGreSQL QUERY
            If a question is missing any necessary information, ASK FOR CLARIFICATION INSTEAD OF RESPONDING WITH PostGreSQL QUERY
            IMPORTANT NOTE:
            1. Always use ILIKE operators for comparing string
            DATABASE SCHEMA:
            {schema}
            Answer using the following Template
            {{Your PostGreSQL Command}}
            """).strip()
        self.review_prompt = textwrap.dedent("""
            You are to review the content provided. Your objective is clear:
            If the content contains an SQL query, extract and present only that SQL query.
            If the content does not contain any SQL query, respond with 'INVALID QUERY'.
            Do not provide additional information or context. Stick strictly to the above guidelines.
            """).strip()
        self.chat_history = []
        # One remembered exchange is two messages (Human + AI).
        self.memory_length = self.options.get('memory', 0) * 2

    def predict(self, user_input: str) -> ModelOutput:
        """
        Turn a natural-language question into a PostgreSQL command.

        :param user_input: the user's question.
        :returns: ModelOutput; ``is_final_output`` is True only when the
            optional review pass confirmed the response contains SQL (or
            on the terminal "Schema not loaded" error).
        """
        if len(self.schema) == 0:
            # No schema means nothing can be answered; mark as final.
            return ModelOutput("Schema not loaded", True)
        if len(user_input) == 0:
            return ModelOutput("I'm sorry, I don't understand your question.", False)
        # Terminate the sentence so the model sees a complete question.
        if user_input[-1] not in '.;:?!':
            user_input += '.'
        system_prompt = self.system_prompt.format(schema=self.schema)
        messages = [SystemMessage(content=system_prompt),
                    HumanMessage(content=user_input)]
        response = self.llm.predict_messages(
            messages=(self.chat_history + messages))
        final_output = False
        if self.options.get('review', False):
            # Second pass: extract a bare SQL query or flag the content
            # as invalid. Only a confirmed extraction replaces the
            # original response and marks the output as final.
            new_response = self.llm.predict_messages(messages=[SystemMessage(
                content=self.review_prompt), HumanMessage(content="Content:\n"+response.content)])
            if 'INVALID' not in new_response.content:
                final_output = True
                response = new_response
        self.chat_history.append(HumanMessage(content=user_input))
        self.chat_history.append(AIMessage(content=response.content))
        # Drop the oldest messages beyond the memory window. Slicing
        # with [excess:] (not [-memory_length:]) keeps the zero-memory
        # case correct: [-0:] would keep everything.
        if len(self.chat_history) > self.memory_length:
            excess = len(self.chat_history) - self.memory_length
            self.chat_history = self.chat_history[excess:]
        return ModelOutput(response.content, final_output)

    def override_system_prompt(self, new_system_prompt: str) -> None:
        """Replace the system prompt. Silently ignored unless the new
        prompt keeps the mandatory '{schema}' placeholder."""
        if '{schema}' in new_system_prompt:
            self.system_prompt = new_system_prompt

    def override_review_prompt(self, new_review_prompt: str) -> None:
        """Replace the prompt used by the optional review pass."""
        self.review_prompt = new_review_prompt

    def load_schema_from_file(self, file_path: str) -> bool:
        """Load the schema from a UTF-8 text file.

        :returns: True on success. (The original was annotated ``-> bool``
            but implicitly returned None; any I/O error still propagates.)
        """
        with open(file_path, 'r', encoding='utf-8') as file:
            contents = file.read()
        return self.load_schema_as_string(contents)

    def load_schema_as_string(self, schema: str) -> bool:
        """Set the schema text used by predict(); returns True."""
        self.schema = schema
        return True

    def clear_chat_history(self) -> None:
        """Forget all previous exchanges."""
        self.chat_history = []
# Dispatch table: maps each OutputTypes member to the concrete model
# class that implements it; consumed by initialize_model().
output_type_class_map = {
    OutputTypes.SQL: NLP2SQL,
    # OutputTypes.API: NLP2API,  # reserved for a later use case
}
def initialize_model(llm, options=None, output_type: OutputTypes = OutputTypes.SQL):
    """
    Based on the Output Type the Model will be instantiated.

    :param llm: LangChain language model handed to the model class.
    :param options: optional dict of model switches (e.g. 'memory',
        'review'); defaults to an empty dict. A None sentinel replaces
        the original mutable default ``options={}``, which was shared
        across calls and could leak state between models.
    :param output_type: which model family to build (currently only SQL).
    :returns: an instance of the class mapped to ``output_type``.
    :raises KeyError: if ``output_type`` has no registered class.
    """
    if options is None:
        options = {}
    model_class = output_type_class_map[output_type]
    model_instance = model_class(llm, options)
    return model_instance