Skip to content

Commit 47b7fe6

Browse files
authored
Merge pull request #26 from hongjin-su/main
Update reranker
2 parents 01b2ed4 + ccf8f67 commit 47b7fe6

File tree

1 file changed

+254
-56
lines changed

1 file changed

+254
-56
lines changed

scripts/AbsTaskRetrieval.py

+254-56
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
"""
2-
To reproduce Reranking experiments with using GritLM for embedding and subsequent reranking replace the AbsTaskRetrieval file in MTEB with this file.
3-
"""
1+
import copy
42
import json
53
import logging
64
from time import time
75
from typing import Dict, List
86

7+
import numpy as np
98
from sentence_transformers import SentenceTransformer
109
from sentence_transformers.models import Transformer, WordEmbeddings
1110
import os
@@ -17,6 +16,191 @@
1716

1817
# Method names a model must expose to be usable as a BEIR dense retriever
# (DenseRetrievalExactSearch). NOTE(review): presumably checked with
# hasattr() elsewhere to decide whether to wrap the model — the check itself
# is outside this diff view; confirm against the full file.
DRES_METHODS = ["encode_queries", "encode_corpus"]
1918

19+
# Prompt templates used to score (query, passage) pairs with a chat-style
# reranker.  Each template is filled via str.format(query=..., passage=...)
# and asks the model for a yes/no judgement; downstream code reads the
# first-token logits and uses the "yes" logit as the reranking score.
#
# Defect fixed: in the original, the identical generic relevance prompt was
# copy-pasted verbatim for 25 of the 26 tasks.  Build the dict from a single
# shared constant instead; only ArguAna (counter-argument detection, not
# plain relevance) differs.  The resulting dict is identical in content and
# insertion order (ArguAna first, then the generic tasks).

# Generic yes/no relevance-judgement prompt shared by all non-ArguAna tasks.
_RELEVANCE_TEMPLATE = (
    "<|user|>\n"
    "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n"
    "Query: {query}\n"
    "Passage: {passage}\n\n"
    "Answer with yes if the passage is relevant to the query, and no otherwise.\n"
    "<|assistant|>\n"
    "Answer:"
)

# Tasks that all use the generic relevance prompt (original dict order kept).
_RELEVANCE_TASKS = [
    "SciFact",
    "NFCorpus",
    "CQADupstackAndroidRetrieval",
    "ClimateFEVER",
    "CQADupstackEnglishRetrieval",
    "CQADupstackGamingRetrieval",
    "CQADupstackGisRetrieval",
    "CQADupstackMathematicaRetrieval",
    "CQADupstackPhysicsRetrieval",
    "CQADupstackProgrammersRetrieval",
    "CQADupstackStatsRetrieval",
    "CQADupstackTexRetrieval",
    "CQADupstackUnixRetrieval",
    "CQADupstackWebmastersRetrieval",
    "CQADupstackWordpressRetrieval",
    "DBPedia",
    "FEVER",
    "FiQA2018",
    "HotpotQA",
    "MSMARCO",
    "NQ",
    "QuoraRetrieval",
    "SCIDOCS",
    "TRECCOVID",
    "Touche2020",
]

TEMPLATES = {
    # ArguAna is the one special case: it looks for counter-arguments on the
    # same topic rather than plain query/passage relevance.
    "ArguAna": (
        "<|user|>\n"
        "Provided two debate paragraphs, check if they are about the same topic, but contain counter-arguments.\n\n"
        "Paragraph 1: {query}\n"
        "Paragraph 2: {passage}\n\n"
        "Answer with yes if paragraph 1 and paragraph 2 are about the same topic, but contain counter-arguments; Answer with no otherwise.\n"
        "<|assistant|>\n"
        "Answer:"
    ),
}
TEMPLATES.update({task: _RELEVANCE_TEMPLATE for task in _RELEVANCE_TASKS})
203+
20204
class AbsTaskRetrieval(AbsTask):
21205
"""
22206
Abstract class for re-ranking experiments.
@@ -46,6 +230,7 @@ def evaluate(
46230
score_function="cos_sim",
47231
**kwargs
48232
):
233+
task_name = kwargs['task_name']
49234
sgpt2_model = model
50235
try:
51236
from beir.retrieval.evaluation import EvaluateRetrieval
@@ -67,7 +252,7 @@ def evaluate(
67252
corpus_chunk_size=corpus_chunk_size if corpus_chunk_size is not None else 50000,
68253
**kwargs,
69254
)
70-
255+
71256
else:
72257
# Distributed (multi-GPU)
73258
from beir.retrieval.search.dense import (
@@ -82,67 +267,80 @@ def evaluate(
82267

83268
retriever = EvaluateRetrieval(model, score_function=score_function) # or "cos_sim" or "dot"
84269
start_time = time()
85-
# if os.path.isfile('SciFact/SciFact.json'):
86-
# with open('SciFact/SciFact.json') as f:
87-
# results = json.load(f)
88-
# else:
270+
# with open(f'qrels/{task_name}.json') as f:
271+
# results = json.load(f)
89272
results = retriever.retrieve(corpus, queries)
90-
# with open('SciFact/SciFact.json','w') as f:
91-
# json.dump(results,f,indent=2)
92273
end_time = time()
93274
sgpt2_model = sgpt2_model.to('cpu')
94275
logger.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
95276
model_rerank = kwargs.get('model_rerank', None)
96-
template = "<|user|>\nI will provide you with {num} passages, each indicated by a numerical identifier []. " \
97-
"Rank the passages based on their relevance to the search query {query}.\n\n{passages}\n\n" \
98-
"Search Query: {query}.\n\n" \
99-
"Rank the {num} passages above based on their relevance to the search query. All the passages " \
100-
"should be included and listed using identifiers, in descending order of relevance. " \
101-
"The output format should be [] > [] > ..., e.g., [4] > [2] > ... " \
102-
"Only respond with the ranking results, do not say any word or explain.\n<|assistant|>\n"
277+
template = TEMPLATES[task_name]
103278
if model_rerank is not None:
279+
if not os.path.isdir(f"rank_cache_new/{task_name}"):
280+
os.makedirs(f"rank_cache_new/{task_name}",exist_ok=True)
281+
model_rerank.tokenizer.pad_token_id = model_rerank.tokenizer.eos_token_id
104282
model_rerank = model_rerank.cuda()
283+
top_k = kwargs.get('tok_k',kwargs['top_k'])
284+
# step_size = kwargs.get('step_size', 2)
285+
# window_size = kwargs.get('window_size', -1)
105286
os.environ.pop("BIDIRECTIONAL_ATTN")
106287
print("BIDIRECTIONAL_ATTN", os.getenv("BIDIRECTIONAL_ATTN", False))
107-
for qid, doc_ids in tqdm(results.items(), desc='reranking'):
108-
# if os.path.isfile(f"SciFact/{qid}.json"):
109-
# with open(f"SciFact/{qid}.json") as f:
110-
# rerank_orders = json.load(f)
111-
# else:
112-
doc_ids = sorted(doc_ids.items(),key=lambda x:x[1],reverse=True)
113-
cur_query = queries[qid]
114-
num = 0
115-
passages = ''
116-
cur_prompt = None
117-
all_ids = {}
118-
scores = []
119-
old_orders = []
120-
while len(model_rerank.tokenizer(template.format(num=num, query=cur_query, passages=passages), return_tensors="pt")["input_ids"][0])<1900:
121-
cur_prompt = template.format(num=num, query=cur_query, passages=passages)
122-
passages += f"[{num}] {corpus[doc_ids[num][0]]['title'] + ' ' + corpus[doc_ids[num][0]]['text']}\n"
123-
old_orders.append(doc_ids[num][0])
124-
all_ids[num] = doc_ids[num][0]
125-
scores.append(doc_ids[num][1])
126-
num += 1
127-
inputs = model_rerank.tokenizer(cur_prompt, return_tensors="pt")["input_ids"].to(model_rerank.device)
128-
generation_output = model_rerank.generate(inputs, max_new_tokens=100, temperature=0.7, do_sample=True)
129-
outputs = model_rerank.tokenizer.batch_decode(generation_output[:, inputs.shape[-1]:])[0].strip('</s>').strip()
130-
components = outputs.split('>')
131-
new_orders = []
132-
for idx,c in enumerate(components):
133-
try:
134-
new_orders.append(all_ids[int(c.strip().strip('[').strip(']').strip())])
135-
except:
136-
print(len(old_orders),outputs)
137-
pass
138-
rerank_orders = {'old_orders':old_orders,'new_orders':new_orders}
139-
# with open(f"SciFact/{qid}.json",'w') as f:
140-
# json.dump(rerank_orders,f,indent=2)
141-
cur_scores = []
142-
for i in rerank_orders['old_orders']:
143-
cur_scores.append(results[qid][i])
144-
for i,s in zip(rerank_orders['new_orders'],cur_scores):
145-
results[qid][i] = s
288+
all_qids = []
289+
for k in results:
290+
all_qids.append(k)
291+
for qid in all_qids:
292+
doc_ids = sorted(results[qid].items(), key=lambda x: x[1], reverse=True)
293+
# remove_doc_ids = [d[0] for d in doc_ids[top_k:]]
294+
# for a_doc_id in remove_doc_ids:
295+
# results[qid].pop(a_doc_id)
296+
all_qids = []
297+
for k in results:
298+
all_qids.append(k)
299+
bar = tqdm(range(len(all_qids)*top_k),desc='reranking')
300+
def print_orders(l,tag):
301+
order_to_print = []
302+
for local_i,o in enumerate(l):
303+
order_to_print.append([local_i,o])
304+
print(order_to_print,tag)
305+
for qid in all_qids:
306+
flag = False
307+
rerank_orders = {}
308+
if os.path.isfile(f"rank_cache_new/{task_name}/{qid}.json"):
309+
# continue
310+
with open(f"rank_cache_new/{task_name}/{qid}.json") as f:
311+
rerank_orders = json.load(f)
312+
if 'old_orders' in rerank_orders and 'new_orders' in rerank_orders:
313+
flag = True
314+
if not flag:
315+
with open(f"rank_cache_new/{task_name}/{qid}.json",'w') as f:
316+
json.dump({},f,indent=2)
317+
doc_ids = sorted(results[qid].items(),key=lambda x:x[1],reverse=True)
318+
orders = [d[0] for d in doc_ids]
319+
old_orders = copy.deepcopy(orders)
320+
new_orders = []
321+
for a_doc_id in orders[:top_k]:
322+
# cut to both query and foc to 600 for ArguAna
323+
cur_prompt = template.format(query=queries[qid][:600],passage=corpus[a_doc_id]['title']+' '+corpus[a_doc_id]['text'][:600])
324+
inputs = model_rerank.tokenizer(cur_prompt, return_tensors="pt")["input_ids"].to(model_rerank.device)
325+
generation_output = model_rerank.generate(inputs, max_new_tokens=1, temperature=0,
326+
do_sample=False, return_dict_in_generate=True,
327+
output_scores=True)
328+
scores = generation_output.scores[0][0].cpu()
329+
new_orders.append([a_doc_id,scores[5081]]) # 708 for no, 5081 for yes
330+
bar.update(1)
331+
new_orders_raw = sorted(new_orders,key=lambda x:x[1],reverse=True)
332+
new_orders = [i[0] for i in new_orders_raw]
333+
rerank_orders = {'old_orders':old_orders,'new_orders':new_orders}
334+
with open(f"rank_cache_new/{task_name}/{qid}.json",'w') as f:
335+
json.dump(rerank_orders,f,indent=2)
336+
# assert set(rerank_orders['new_orders'])==set(rerank_orders['old_orders'])
337+
# assert set(rerank_orders['new_orders'])==set(list(results[qid].keys()))
338+
# selected_scores = []
339+
# for rank_id,o in enumerate(rerank_orders['new_orders']):
340+
# selected_scores.append(results[qid][o])
341+
# selected_scores = sorted(selected_scores,reverse=True)
342+
# for rank_id,o in enumerate(rerank_orders['new_orders']):
343+
# results[qid][o] += (10-rank_id)/kwargs['divisor']
146344
os.environ["BIDIRECTIONAL_ATTN"] = 'true'
147345
print("BIDIRECTIONAL_ATTN", os.getenv("BIDIRECTIONAL_ATTN", False))
148346

0 commit comments

Comments
 (0)