- """
- To reproduce the reranking experiments that use GritLM for embedding and subsequent reranking, replace the AbsTaskRetrieval file in MTEB with this file.
- """
+ import copy
import json
import logging
from time import time
from typing import Dict, List

+ import numpy as np
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, WordEmbeddings
import os

DRES_METHODS = ["encode_queries", "encode_corpus"]
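+ # Prompt templates for GritLM-based pointwise reranking. ArguAna (counter-argument
+ # retrieval) gets a task-specific prompt; every other task shares one generic
+ # relevance-judgment prompt, so it is defined once below and reused.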
+ RELEVANCE_TEMPLATE = "<|user|>\n" \
+     "Given a query and a passage, judge whether the passage is relevant to the query or not.\n\n" \
+     "Query: {query}\n" \
+     "Passage: {passage}\n\n" \
+     "Answer with yes if the passage is relevant to the query, and no otherwise.\n" \
+     "<|assistant|>\n" \
+     "Answer:"
+
+ TEMPLATES = {
+     "ArguAna": "<|user|>\n" \
+         "Provided two debate paragraphs, check if they are about the same topic, but contain counter-arguments.\n\n" \
+         "Paragraph 1: {query}\n" \
+         "Paragraph 2: {passage}\n\n" \
+         "Answer with yes if paragraph 1 and paragraph 2 are about the same topic, but contain counter-arguments; Answer with no otherwise.\n" \
+         "<|assistant|>\n" \
+         "Answer:",
+ }
+ for _task in [
+     "SciFact", "NFCorpus", "ClimateFEVER", "DBPedia", "FEVER", "FiQA2018",
+     "HotpotQA", "MSMARCO", "NQ", "QuoraRetrieval", "SCIDOCS", "TRECCOVID",
+     "Touche2020", "CQADupstackAndroidRetrieval", "CQADupstackEnglishRetrieval",
+     "CQADupstackGamingRetrieval", "CQADupstackGisRetrieval",
+     "CQADupstackMathematicaRetrieval", "CQADupstackPhysicsRetrieval",
+     "CQADupstackProgrammersRetrieval", "CQADupstackStatsRetrieval",
+     "CQADupstackTexRetrieval", "CQADupstackUnixRetrieval",
+     "CQADupstackWebmastersRetrieval", "CQADupstackWordpressRetrieval",
+ ]:
+     TEMPLATES[_task] = RELEVANCE_TEMPLATE
+
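+ # A minimal usage sketch (hypothetical driver, not part of this commit): evaluate()
+ # below expects `task_name` and `top_k` in kwargs, plus an optional generative
+ # `model_rerank` whose tokenizer maps "no"/"yes" to token ids 708/5081.
+ #
+ #     task.evaluate(model, split="test", task_name="SciFact",
+ #                   top_k=10, model_rerank=rerank_model)
+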
class AbsTaskRetrieval(AbsTask):
    """
    Abstract class for re-ranking experiments.
@@ -46,6 +230,7 @@ def evaluate(
        score_function="cos_sim",
        **kwargs
    ):
+        task_name = kwargs['task_name']
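+        # `task_name` (assumed to arrive via kwargs from the task runner) selects the
+        # prompt template and the on-disk rerank cache directory used below.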
        sgpt2_model = model
        try:
            from beir.retrieval.evaluation import EvaluateRetrieval
@@ -67,7 +252,7 @@ def evaluate(
                corpus_chunk_size=corpus_chunk_size if corpus_chunk_size is not None else 50000,
                **kwargs,
            )
-
+
        else:
            # Distributed (multi-GPU)
            from beir.retrieval.search.dense import (
@@ -82,67 +267,80 @@ def evaluate(

        retriever = EvaluateRetrieval(model, score_function=score_function)  # or "cos_sim" or "dot"
        start_time = time()
-        # if os.path.isfile('SciFact/SciFact.json'):
-        #     with open('SciFact/SciFact.json') as f:
-        #         results = json.load(f)
-        # else:
+        # with open(f'qrels/{task_name}.json') as f:
+        #     results = json.load(f)
        results = retriever.retrieve(corpus, queries)
-        # with open('SciFact/SciFact.json', 'w') as f:
-        #     json.dump(results, f, indent=2)
        end_time = time()
        sgpt2_model = sgpt2_model.to('cpu')
        logger.info("Time taken to retrieve: {:.2f} seconds".format(end_time - start_time))
        model_rerank = kwargs.get('model_rerank', None)
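+        # `model_rerank`, when provided, is a generative LM used below to rescore the
+        # retrieved candidates with a yes/no relevance prompt.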
-        template = "<|user|>\nI will provide you with {num} passages, each indicated by a numerical identifier []. " \
-            "Rank the passages based on their relevance to the search query {query}.\n\n{passages}\n\n" \
-            "Search Query: {query}.\n\n" \
-            "Rank the {num} passages above based on their relevance to the search query. All the passages " \
-            "should be included and listed using identifiers, in descending order of relevance. " \
-            "The output format should be [] > [] > ..., e.g., [4] > [2] > ... " \
-            "Only respond with the ranking results, do not say any word or explain.\n<|assistant|>\n"
+        template = TEMPLATES[task_name]
        if model_rerank is not None:
+            if not os.path.isdir(f"rank_cache_new/{task_name}"):
+                os.makedirs(f"rank_cache_new/{task_name}", exist_ok=True)
+            model_rerank.tokenizer.pad_token_id = model_rerank.tokenizer.eos_token_id
            model_rerank = model_rerank.cuda()
+            top_k = kwargs['top_k']
+            # step_size = kwargs.get('step_size', 2)
+            # window_size = kwargs.get('window_size', -1)
            os.environ.pop("BIDIRECTIONAL_ATTN", None)
            print("BIDIRECTIONAL_ATTN", os.getenv("BIDIRECTIONAL_ATTN", False))
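+            # GritLM embeds with bidirectional attention but generates causally, so the
+            # flag is cleared for the generative reranking pass and restored afterwards.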
-            for qid, doc_ids in tqdm(results.items(), desc='reranking'):
-                # if os.path.isfile(f"SciFact/{qid}.json"):
-                #     with open(f"SciFact/{qid}.json") as f:
-                #         rerank_orders = json.load(f)
-                # else:
-                doc_ids = sorted(doc_ids.items(), key=lambda x: x[1], reverse=True)
-                cur_query = queries[qid]
-                num = 0
-                passages = ''
-                cur_prompt = None
-                all_ids = {}
-                scores = []
-                old_orders = []
-                while len(model_rerank.tokenizer(template.format(num=num, query=cur_query, passages=passages), return_tensors="pt")["input_ids"][0]) < 1900:
-                    cur_prompt = template.format(num=num, query=cur_query, passages=passages)
-                    passages += f"[{num}] {corpus[doc_ids[num][0]]['title'] + ' ' + corpus[doc_ids[num][0]]['text']}\n"
-                    old_orders.append(doc_ids[num][0])
-                    all_ids[num] = doc_ids[num][0]
-                    scores.append(doc_ids[num][1])
-                    num += 1
-                inputs = model_rerank.tokenizer(cur_prompt, return_tensors="pt")["input_ids"].to(model_rerank.device)
-                generation_output = model_rerank.generate(inputs, max_new_tokens=100, temperature=0.7, do_sample=True)
-                outputs = model_rerank.tokenizer.batch_decode(generation_output[:, inputs.shape[-1]:])[0].strip('</s>').strip()
-                components = outputs.split('>')
-                new_orders = []
-                for idx, c in enumerate(components):
-                    try:
-                        new_orders.append(all_ids[int(c.strip().strip('[').strip(']').strip())])
-                    except:
-                        print(len(old_orders), outputs)
-                rerank_orders = {'old_orders': old_orders, 'new_orders': new_orders}
-                # with open(f"SciFact/{qid}.json", 'w') as f:
-                #     json.dump(rerank_orders, f, indent=2)
-                cur_scores = []
-                for i in rerank_orders['old_orders']:
-                    cur_scores.append(results[qid][i])
-                for i, s in zip(rerank_orders['new_orders'], cur_scores):
-                    results[qid][i] = s
+            all_qids = list(results.keys())
+            # Candidate pruning beyond top_k is disabled; the reranking loop below only
+            # touches the top_k hits per query anyway.
+            # for qid in all_qids:
+            #     doc_ids = sorted(results[qid].items(), key=lambda x: x[1], reverse=True)
+            #     remove_doc_ids = [d[0] for d in doc_ids[top_k:]]
+            #     for a_doc_id in remove_doc_ids:
+            #         results[qid].pop(a_doc_id)
+            bar = tqdm(range(len(all_qids) * top_k), desc='reranking')
+
+            def print_orders(l, tag):
+                # Debug helper: print (position, doc_id) pairs for a ranking.
+                order_to_print = []
+                for local_i, o in enumerate(l):
+                    order_to_print.append([local_i, o])
+                print(order_to_print, tag)
+
+            for qid in all_qids:
+                flag = False
+                rerank_orders = {}
+                if os.path.isfile(f"rank_cache_new/{task_name}/{qid}.json"):
+                    # continue
+                    with open(f"rank_cache_new/{task_name}/{qid}.json") as f:
+                        rerank_orders = json.load(f)
+                    if 'old_orders' in rerank_orders and 'new_orders' in rerank_orders:
+                        flag = True
+                if not flag:
+                    # Write an empty cache file first, presumably so parallel runs (with
+                    # the `continue` above enabled) skip in-progress queries.
+                    with open(f"rank_cache_new/{task_name}/{qid}.json", 'w') as f:
+                        json.dump({}, f, indent=2)
+                    doc_ids = sorted(results[qid].items(), key=lambda x: x[1], reverse=True)
+                    orders = [d[0] for d in doc_ids]
+                    old_orders = copy.deepcopy(orders)
+                    new_orders = []
+                    for a_doc_id in orders[:top_k]:
+                        # Cut both the query and the doc to 600 characters (needed for ArguAna).
+                        cur_prompt = template.format(query=queries[qid][:600], passage=corpus[a_doc_id]['title'] + ' ' + corpus[a_doc_id]['text'][:600])
+                        inputs = model_rerank.tokenizer(cur_prompt, return_tensors="pt")["input_ids"].to(model_rerank.device)
+                        generation_output = model_rerank.generate(inputs, max_new_tokens=1, temperature=0,
+                                                                  do_sample=False, return_dict_in_generate=True,
+                                                                  output_scores=True)
+                        scores = generation_output.scores[0][0].cpu()
+                        new_orders.append([a_doc_id, scores[5081]])  # token id 708 is "no", 5081 is "yes"
+                        bar.update(1)
+                    new_orders_raw = sorted(new_orders, key=lambda x: x[1], reverse=True)
+                    new_orders = [i[0] for i in new_orders_raw]
+                    rerank_orders = {'old_orders': old_orders, 'new_orders': new_orders}
+                    with open(f"rank_cache_new/{task_name}/{qid}.json", 'w') as f:
+                        json.dump(rerank_orders, f, indent=2)
+                # assert set(rerank_orders['new_orders']) == set(rerank_orders['old_orders'])
+                # assert set(rerank_orders['new_orders']) == set(list(results[qid].keys()))
+                # selected_scores = []
+                # for rank_id, o in enumerate(rerank_orders['new_orders']):
+                #     selected_scores.append(results[qid][o])
+                # selected_scores = sorted(selected_scores, reverse=True)
+                # for rank_id, o in enumerate(rerank_orders['new_orders']):
+                #     results[qid][o] += (10 - rank_id) / kwargs['divisor']
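+            # Note: the reranked order is only persisted to the cache here; the
+            # commented-out block above shows how it would be folded back into
+            # `results` for immediate scoring.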
            os.environ["BIDIRECTIONAL_ATTN"] = 'true'
            print("BIDIRECTIONAL_ATTN", os.getenv("BIDIRECTIONAL_ATTN", False))