rag.py

import argparse
import json
import os
import re
import time

import pandas as pd
from lotus.models import E5Model, OpenAIModel
from tag.utils import IndexMerger, row_to_str

# Taken from STaRK - https://arxiv.org/pdf/2404.13207
RERANK_PROMPT_TEMPLATE = (
    "You are a helpful assistant that examines if a row "
    "satisfies a given query and assign a score from 0.0 to 1.0. "
    "If the row does not satisfy the query, the score should be 0.0. "
    "If there exists explicit and strong evidence supporting that row "
    "satisfies the query, the score should be 1.0. If partial evidence or weak "
    "evidence exists, the score should be between 0.0 and 1.0.\n"
    'Here is the query:\n"{query}"\n'
    "Here is the information about the row:\n{row_str}\n\n"
    "Please score the row based on how well it satisfies the query. "
    "ONLY output the floating point score WITHOUT anything else. "
    "Output: The numeric score of this row is: "
)
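
# A minimal sketch of how the template is rendered; the query and row string
# below are hypothetical, not taken from the benchmark:
#   RERANK_PROMPT_TEMPLATE.format(
#       query="Which schools are in Alameda County?",
#       row_str="County: Alameda | School: Alameda High School",
#   )
# The model is expected to reply with a bare float such as "0.8"; run_row below
# extracts the first number from the completion with a regex.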


def run_row(args, query_row):
    """Retrieve rows for one benchmark query, optionally rerank them, and ask the LM to answer."""
    index_merger = db_used_to_index_merger[query_row["DB used"]]
    question = query_row["Query"]
    # Gold answers are stored as Python literals where possible; fall back to the raw string.
    try:
        answer = eval(query_row["Answer"])
    except Exception:
        answer = query_row["Answer"]

    tic = time.time()
    results = index_merger(question, args.ret_k)
    if args.rerank:
        # Ask the LM to score each retrieved row against the query, then sort by score.
        for result in results:
            row_str = row_to_str(result.row)
            prompt = RERANK_PROMPT_TEMPLATE.format(query=question, row_str=row_str)
            reranker_score = lm([[{"role": "user", "content": prompt}]])[0]
            # The prompt asks for a bare float; pull the first number out of the completion.
            reranker_score = float(re.findall(r"[-+]?\d*\.\d+|\d+", reranker_score)[0])
            result.reranker_score = reranker_score
        results = sorted(results, key=lambda x: x.reranker_score, reverse=True)

    # Concatenate the retrieved rows into the user prompt, followed by the question.
    user_instruction = ""
    for i, result in enumerate(results):
        user_instruction += f"Data Point {i+1}\n{row_to_str(result.row)}\n\n"
    user_instruction += f"Question: {question}"

    if query_row["Query type"] == "Aggregation":
        system_instruction = (
            "You will be given a list of data points and a question. Use the data points to answer the question. "
            "If a value is a string, it must be enclosed in double quotes."
        )
    else:
        system_instruction = (
            "You will be given a list of data points and a question. Use the data points to answer the question. "
            "Your answer must be a list of values that is evaluatable in Python. Respond in the format [value1, value2, ..., valueN]. "
            "If you are unable to answer the question, respond with []. Respond with only the list of values and nothing else. "
            "If a value is a string, it must be enclosed in double quotes."
        )

    messages = [[{"role": "system", "content": system_instruction}, {"role": "user", "content": user_instruction}]]
    prediction = lm(messages)[0]
    latency = time.time() - tic

    # Non-aggregation predictions should be Python-evaluatable lists; keep the raw string if parsing fails.
    try:
        prediction = eval(prediction)
    except Exception:
        print(f"Error evaluating prediction: {prediction}")
    if not isinstance(answer, list):
        answer = [answer]
    return {
        "prediction": prediction,
        "answer": answer,
        "query_id": query_row["Query ID"],
        "latency": latency,
    }
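
# For reference, a record returned by run_row has this shape (the values here
# are illustrative, not real benchmark output):
#   {"prediction": ["Alameda High School"], "answer": ["Alameda High School"],
#    "query_id": 3, "latency": 2.4}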


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--rerank", action="store_true")
    parser.add_argument("--df_path", default="../tag_queries.csv")
    parser.add_argument("--ret_k", type=int, default=5)
    parser.add_argument("--output_dir")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    queries_df = pd.read_csv(args.df_path)

    # Llama 3.1 70B served through a local vLLM OpenAI-compatible endpoint; E5 is the retrieval model.
    lm = OpenAIModel(
        model="meta-llama/Meta-Llama-3.1-70B-Instruct", api_base="http://localhost:8000/v1", provider="vllm"
    )
    rm = E5Model()

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # Map each database name in the query CSV's "DB used" column to an IndexMerger over all of its tables.
    db_used_to_index_merger = {
        "california_schools": IndexMerger(
            [("california_schools", "frpm"), ("california_schools", "satscores"), ("california_schools", "schools")], rm
        ),
        "codebase_community": IndexMerger(
            [
                ("codebase_community", "badges"),
                ("codebase_community", "comments"),
                ("codebase_community", "postHistory"),
                ("codebase_community", "postLinks"),
                ("codebase_community", "posts"),
                ("codebase_community", "tags"),
                ("codebase_community", "users"),
                ("codebase_community", "votes"),
            ],
            rm,
        ),
        "debit_card_specializing": IndexMerger(
            [
                ("debit_card_specializing", "customers"),
                ("debit_card_specializing", "gasstations"),
                ("debit_card_specializing", "products"),
                ("debit_card_specializing", "sqlite_sequence"),
                ("debit_card_specializing", "transactions_1k"),
                ("debit_card_specializing", "yearmonth"),
            ],
            rm,
        ),
        "european_football_2": IndexMerger(
            [
                ("european_football_2", "Country"),
                ("european_football_2", "League"),
                ("european_football_2", "Match"),
                ("european_football_2", "Player"),
                ("european_football_2", "Player_Attributes"),
                ("european_football_2", "sqlite_sequence"),
                ("european_football_2", "Team"),
                ("european_football_2", "Team_Attributes"),
            ],
            rm,
        ),
        "formula_1": IndexMerger(
            [
                ("formula_1", "circuits"),
                ("formula_1", "constructorResults"),
                ("formula_1", "constructors"),
                ("formula_1", "constructorStandings"),
                ("formula_1", "drivers"),
                ("formula_1", "driverStandings"),
                ("formula_1", "lapTimes"),
                ("formula_1", "pitStops"),
                ("formula_1", "qualifying"),
                ("formula_1", "races"),
                ("formula_1", "results"),
                ("formula_1", "seasons"),
                ("formula_1", "sqlite_sequence"),
                ("formula_1", "status"),
            ],
            rm,
        ),
    }

    # Answer every query; optionally persist each result as its own JSON file.
    for _, query_row in queries_df.iterrows():
        output = run_row(args, query_row)
        print(output)
        if args.output_dir:
            with open(os.path.join(args.output_dir, f"query_{query_row['Query ID']}.json"), "w+") as f:
                json.dump(output, f)
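
# Example invocation, assuming the vLLM server above is already running and the
# query CSV sits at the default relative path:
#   python rag.py --rerank --ret_k 5 --output_dir outputs/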