diff --git a/examples/llm/tech_qa.py b/examples/llm/tech_qa.py
index 593bddf3bd92..23d8e9ea4432 100644
--- a/examples/llm/tech_qa.py
+++ b/examples/llm/tech_qa.py
@@ -225,10 +225,11 @@ def make_dataset(args):
     # TODO add reranking NIM to VectorRAG
     query_loader = RAGQueryLoader(
         data=(fs, gs), seed_nodes_kwargs={"k_nodes": knn_neighsample_bs},
-        sampler_kwargs={"num_neighbors": [fanout] * num_hops},
-        local_filter=make_pcst_filter(triples, model),
+        sampler_kwargs={"num_neighbors": [fanout] * num_hops
+                        }, local_filter=make_pcst_filter(triples, model),
         local_filter_kwargs=local_filter_kwargs, raw_docs=context_docs,
-        embedded_docs=embedded_docs)
+        embedded_docs=embedded_docs, use_nvidia_rerank=True,
+        NIM_KEY_FOR_RERANK=args.NV_NIM_KEY)
     total_data_list = []
     extracted_triple_sizes = []
     for data_point in tqdm(rawset, desc="Building un-split dataset"):
diff --git a/torch_geometric/loader/rag_loader.py b/torch_geometric/loader/rag_loader.py
index 3fe9e4b8b6d8..a62a9178c86e 100644
--- a/torch_geometric/loader/rag_loader.py
+++ b/torch_geometric/loader/rag_loader.py
@@ -57,16 +57,22 @@ def register_feature_store(self, feature_store: FeatureStore):
 
 class RAGQueryLoader:
     """Loader meant for making RAG queries from a remote backend."""
-    def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
-                 local_filter: Optional[Callable[[Data, Any], Data]] = None,
-                 seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
-                 seed_edges_kwargs: Optional[Dict[str, Any]] = None,
-                 sampler_kwargs: Optional[Dict[str, Any]] = None,
-                 loader_kwargs: Optional[Dict[str, Any]] = None,
-                 local_filter_kwargs: Optional[Dict[str, Any]] = None,
-                 raw_docs: Optional[List[str]] = None,
-                 embedded_docs: Optional[Tensor] = None,
-                 k_for_docs: Optional[int] = 2):
+    def __init__(
+        self,
+        data: Tuple[RAGFeatureStore, RAGGraphStore],
+        local_filter: Optional[Callable[[Data, Any], Data]] = None,
+        seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
+        seed_edges_kwargs: Optional[Dict[str, Any]] = None,
+        sampler_kwargs: Optional[Dict[str, Any]] = None,
+        loader_kwargs: Optional[Dict[str, Any]] = None,
+        local_filter_kwargs: Optional[Dict[str, Any]] = None,
+        raw_docs: Optional[List[str]] = None,
+        embedded_docs: Optional[Tensor] = None,
+        k_for_docs: Optional[int] = 2,
+        use_nvidia_rerank: Optional[bool] = False,
+        top_n_for_rerank: Optional[int] = 50,
+        NIM_KEY_FOR_RERANK: Optional[str] = '',
+    ):
         """Loader meant for making queries from a remote backend.
 
         Args:
@@ -94,6 +100,13 @@ def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
                 Needs to match the `raw_docs`. Defaults to None.
             k_for_docs (Optional[int], optional): top-k docs to select
                 for vectorRAG. (Default: :obj:`2`).
+            use_nvidia_rerank (Optional[bool], optional): Re-rank the top
+                `top_n_for_rerank` docs before selecting the top `k_for_docs`;
+                `top_n_for_rerank` should exceed `k_for_docs`. (Default: :obj:`False`).
+            top_n_for_rerank (Optional[int], optional): Number of docs to
+                pass to the NVIDIA reranker. (Default: :obj:`50`).
+            NIM_KEY_FOR_RERANK (Optional[str], optional): NIM API key needed
+                for the reranker. (Default: :obj:`''`).
         """
         fstore, gstore = data
         self.raw_docs = raw_docs
@@ -112,6 +125,9 @@ def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
         self.sampler_kwargs = sampler_kwargs or {}
         self.loader_kwargs = loader_kwargs or {}
         self.local_filter_kwargs = local_filter_kwargs or {}
+        self.use_nvidia_rerank = use_nvidia_rerank
+        self.top_n_for_rerank = top_n_for_rerank
+        self.NIM_KEY_FOR_RERANK = NIM_KEY_FOR_RERANK
 
     def query(self, query: Any) -> Data:
         """Retrieve a subgraph associated with the query with all its feature
@@ -154,9 +170,50 @@ def query(self, query: Any) -> Data:
         if self.local_filter:
             data = self.local_filter(data, query, **self.local_filter_kwargs)
         if self.raw_docs:
-            selected_doc_idxs, _ = next(
+            selected_doc_idxs, _, all_idxs = next(
                 batch_knn(query_enc, self.embedded_docs, self.k_for_docs))
+            if self.use_nvidia_rerank:
+                topN_ids = all_idxs[:self.top_n_for_rerank]
+                for retry in range(10):
+                    try:
+                        reranked = _rerank(
+                            query, [self.raw_docs[j] for j in topN_ids],
+                            self.NIM_KEY_FOR_RERANK)
+                        break
+                    except Exception as e:  # noqa
+                        print("Retrying after", e)
+                        print("...")
+                reranked_ids = topN_ids[reranked]
+                selected_doc_idxs = reranked_ids[:self.k_for_docs]
             data.text_context = "\n".join(
                 [self.raw_docs[i] for i in selected_doc_idxs])
         return data
+
+
+def _rerank(query, passages, key):
+    import requests
+    invoke_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking"
+    reranker_model_name = "nvidia/llama-3.2-nv-rerankqa-1b-v2"
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Accept": "application/json",
+    }
+
+    payload = {
+        "model": reranker_model_name,
+        "query": {
+            "text": query
+        },
+        "passages": [{
+            "text": p
+        } for p in passages],
+        # No truncation: let over-length passages fail explicitly
+        "truncate": "NONE",
+    }
+    # re-use connections
+    session = requests.Session()
+    response = session.post(invoke_url, headers=headers, json=payload)
+    response.raise_for_status()
+    response_body = response.json()
+    return [x['index'] for x in response_body['rankings']]
diff --git a/torch_geometric/nn/nlp/txt2kg.py b/torch_geometric/nn/nlp/txt2kg.py
index 38ccdf3e9680..b6043b74e514 100644
--- a/torch_geometric/nn/nlp/txt2kg.py
+++ b/torch_geometric/nn/nlp/txt2kg.py
@@ -173,9 +173,10 @@ def add_doc_2_KG(
                             self.NVIDIA_API_KEY, self.NIM_MODEL),
                     nprocs=num_procs)
                 break
-            except:  # noqa
+            except Exception as e:  # noqa
                 # keep retrying, txt2kg is costly -> stoppage is costly
-                pass
+                print("Retrying after", e)
+                print("...")
 
         # Collect the results from each process
         self.relevant_triples[key] = []
@@ -184,8 +185,9 @@ def add_doc_2_KG(
                         "/tmp/outs_for_proc_" + str(rank))
                     os.remove("/tmp/outs_for_proc_" + str(rank))
                     break
-                except:
-                    pass
+                except Exception as e:  # noqa
+                    print("Retrying after", e)
+                    print("...")
 
         # Increment the doc_id_counter for the next document
         self.doc_id_counter += 1
diff --git a/torch_geometric/utils/rag/feature_store.py b/torch_geometric/utils/rag/feature_store.py
index 555b32ce6cc2..0221921c1011 100644
--- a/torch_geometric/utils/rag/feature_store.py
+++ b/torch_geometric/utils/rag/feature_store.py
@@ -21,7 +21,7 @@ def batch_knn(query_enc: Tensor, embeds: Tensor,
     topk = min(k, len(embeds))
     for i, q in enumerate(prizes):
         _, indices = torch.topk(q, topk, largest=True)
-        yield indices, query_enc[i].unsqueeze(0)
+        yield indices, query_enc[i].unsqueeze(0), torch.argsort(q, descending=True)
 
 
 # NOTE: Only compatible with Homogeneous graphs for now
@@ -68,7 +68,7 @@ def retrieve_seed_nodes(self, query: Any, k_nodes: int = 5) -> InputNodes:
         Returns:
             - The indices of the most similar nodes.
         """
-        result, query_enc = next(
+        result, query_enc, _ = next(
             self._retrieve_seed_nodes_batch([query], k_nodes))
         gc.collect()
         torch.cuda.empty_cache()
@@ -102,7 +102,7 @@ def retrieve_seed_edges(self, query: Any, k_edges: int = 3) -> InputEdges:
         Returns:
             - The indices of the most similar edges.
         """
-        result, query_enc = next(
+        result, query_enc, _ = next(
            self._retrieve_seed_edges_batch([query], k_edges))
         gc.collect()
         torch.cuda.empty_cache()
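
Usage sketch (not part of the patch): a minimal example of how the new reranking options slot into the `RAGQueryLoader` API, loosely mirroring `examples/llm/tech_qa.py`. The store construction, document lists, embeddings, and the environment variable holding the NIM key are placeholders, not names defined by this change.

```python
# Hypothetical usage of the new reranking path. `build_stores`, `context_docs`,
# `doc_embeddings`, and the `NVIDIA_API_KEY` env var are placeholders assumed
# to be provided by the calling script; they are not part of this patch.
import os

from torch_geometric.loader.rag_loader import RAGQueryLoader

fs, gs = build_stores()  # assumed (RAGFeatureStore, RAGGraphStore) pair

loader = RAGQueryLoader(
    data=(fs, gs),
    seed_nodes_kwargs={"k_nodes": 1024},
    sampler_kwargs={"num_neighbors": [40] * 2},
    raw_docs=context_docs,            # List[str] of source documents
    embedded_docs=doc_embeddings,     # Tensor aligned with `context_docs`
    k_for_docs=2,                     # docs kept in `data.text_context`
    use_nvidia_rerank=True,           # oversample via kNN, then re-rank via NIM
    top_n_for_rerank=50,              # candidates sent to the reranker
    NIM_KEY_FOR_RERANK=os.environ["NVIDIA_API_KEY"],  # assumed env var name
)

data = loader.query("How do I reset the admin password?")
print(data.text_context)  # top-k docs after re-ranking, newline-joined
```

With `use_nvidia_rerank=True`, `batch_knn` additionally yields every candidate index sorted by similarity, the loader keeps the first `top_n_for_rerank` of them, sends the corresponding raw docs to the NIM reranking endpoint via `_rerank`, and only then selects the final `k_for_docs`; with the flag left at its default, behaviour is unchanged.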