Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rerank #9999

Open
wants to merge 23 commits into
base: latest-txt2kg
Choose a base branch
from
Open

Rerank #9999

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions examples/llm/tech_qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,10 +225,11 @@ def make_dataset(args):
# TODO add reranking NIM to VectorRAG
query_loader = RAGQueryLoader(
data=(fs, gs), seed_nodes_kwargs={"k_nodes": knn_neighsample_bs},
sampler_kwargs={"num_neighbors": [fanout] * num_hops},
local_filter=make_pcst_filter(triples, model),
sampler_kwargs={"num_neighbors": [fanout] * num_hops
}, local_filter=make_pcst_filter(triples, model),
local_filter_kwargs=local_filter_kwargs, raw_docs=context_docs,
embedded_docs=embedded_docs)
embedded_docs=embedded_docs, use_nvidia_rerank=True,
NIM_KEY_FOR_RERANK=args.NV_NIM_KEY)
total_data_list = []
extracted_triple_sizes = []
for data_point in tqdm(rawset, desc="Building un-split dataset"):
Expand Down
79 changes: 68 additions & 11 deletions torch_geometric/loader/rag_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,22 @@ def register_feature_store(self, feature_store: FeatureStore):

class RAGQueryLoader:
"""Loader meant for making RAG queries from a remote backend."""
def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
local_filter: Optional[Callable[[Data, Any], Data]] = None,
seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
seed_edges_kwargs: Optional[Dict[str, Any]] = None,
sampler_kwargs: Optional[Dict[str, Any]] = None,
loader_kwargs: Optional[Dict[str, Any]] = None,
local_filter_kwargs: Optional[Dict[str, Any]] = None,
raw_docs: Optional[List[str]] = None,
embedded_docs: Optional[Tensor] = None,
k_for_docs: Optional[int] = 2):
def __init__(
self,
data: Tuple[RAGFeatureStore, RAGGraphStore],
local_filter: Optional[Callable[[Data, Any], Data]] = None,
seed_nodes_kwargs: Optional[Dict[str, Any]] = None,
seed_edges_kwargs: Optional[Dict[str, Any]] = None,
sampler_kwargs: Optional[Dict[str, Any]] = None,
loader_kwargs: Optional[Dict[str, Any]] = None,
local_filter_kwargs: Optional[Dict[str, Any]] = None,
raw_docs: Optional[List[str]] = None,
embedded_docs: Optional[Tensor] = None,
k_for_docs: Optional[int] = 2,
use_nvidia_rerank: Optional[bool] = False,
top_n_for_rerank: Optional[int] = 50,
NIM_KEY_FOR_RERANK: Optional[str] = '',
):
"""Loader meant for making queries from a remote backend.

Args:
Expand Down Expand Up @@ -94,6 +100,13 @@ def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
Needs to match the `raw_docs`. Defaults to None.
k_for_docs (Optional[int], optional): top-k docs to select for vectorRAG.
(Default: :obj:`2`).
use_nvidia_rerank (Optional[bool], optional): Take the `top_n_for_rerank` docs and re-order them
before selecting the top `k_for_docs`. `top_n_for_rerank` should be greater than `k_for_docs`.
(Default: :obj:`False`).
top_n_for_rerank (Optional[int], optional): Number of docs to pass to NVIDIA reranker.
(Default: :obj:`50`).
NIM_KEY_FOR_RERANK: Optional[str], optional: NIM API Key needed for using reranker.
(Default: obj:`''`)
"""
fstore, gstore = data
self.raw_docs = raw_docs
Expand All @@ -112,6 +125,9 @@ def __init__(self, data: Tuple[RAGFeatureStore, RAGGraphStore],
self.sampler_kwargs = sampler_kwargs or {}
self.loader_kwargs = loader_kwargs or {}
self.local_filter_kwargs = local_filter_kwargs or {}
self.use_nvidia_rerank = use_nvidia_rerank
self.top_n_for_rerank = top_n_for_rerank
self.NIM_KEY_FOR_RERANK = NIM_KEY_FOR_RERANK

def query(self, query: Any) -> Data:
"""Retrieve a subgraph associated with the query with all its feature
Expand Down Expand Up @@ -154,9 +170,50 @@ def query(self, query: Any) -> Data:
if self.local_filter:
data = self.local_filter(data, query, **self.local_filter_kwargs)
if self.raw_docs:
selected_doc_idxs, _ = next(
selected_doc_idxs, _, all_idxs = next(
batch_knn(query_enc, self.embedded_docs, self.k_for_docs))
if self.use_nvidia_rerank:
topN_ids = all_idxs[:self.top_n_for_rerank]
for retry in range(10):
try:
reranked = _rerank(
query, [self.raw_docs[j] for j in topN_ids],
self.NIM_KEY_FOR_RERANK)
break
except Exception as e: # noqa
print("Retrying after", e)
print("...")
reranked_ids = topN_ids[reranked]
selected_doc_idxs = reranked_ids[:self.k_for_docs]
data.text_context = "\n".join(
[self.raw_docs[i] for i in selected_doc_idxs])

return data


def _rerank(query, passages, key):
import requests
invoke_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking"
reranker_model_name = "nvidia/llama-3.2-nv-rerankqa-1b-v2"
headers = {
"Authorization": f"Bearer {key}",
"Accept": "application/json",
}

payload = {
"model": reranker_model_name,
"query": {
"text": query
},
"passages": [{
"text": p
} for p in passages],
"truncate":
"NONE" # No truncation, if passage is longer than context window, let it fail explicitly
}
# re-use connections
session = requests.Session()
response = session.post(invoke_url, headers=headers, json=payload)
response.raise_for_status()
response_body = response.json()
return [x['index'] for x in response_body['rankings']]
10 changes: 6 additions & 4 deletions torch_geometric/nn/nlp/txt2kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,10 @@ def add_doc_2_KG(
self.NVIDIA_API_KEY, self.NIM_MODEL),
nprocs=num_procs)
break
except: # noqa
except Exception as e: # noqa
# keep retrying, txt2kg is costly -> stoppage is costly
pass
print("Retrying after", e)
print("...")

# Collect the results from each process
self.relevant_triples[key] = []
Expand All @@ -184,8 +185,9 @@ def add_doc_2_KG(
"/tmp/outs_for_proc_" + str(rank))
os.remove("/tmp/outs_for_proc_" + str(rank))
break
except:
pass
except Exception as e: # noqa
print("Retrying after", e)
print("...")
# Increment the doc_id_counter for the next document
self.doc_id_counter += 1

Expand Down
6 changes: 3 additions & 3 deletions torch_geometric/utils/rag/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def batch_knn(query_enc: Tensor, embeds: Tensor,
topk = min(k, len(embeds))
for i, q in enumerate(prizes):
_, indices = torch.topk(q, topk, largest=True)
yield indices, query_enc[i].unsqueeze(0)
yield indices, query_enc[i].unsqueeze(0), torch.argsort(q)


# NOTE: Only compatible with Homogeneous graphs for now
Expand Down Expand Up @@ -68,7 +68,7 @@ def retrieve_seed_nodes(self, query: Any, k_nodes: int = 5) -> InputNodes:
Returns:
- The indices of the most similar nodes.
"""
result, query_enc = next(
result, query_enc, _ = next(
self._retrieve_seed_nodes_batch([query], k_nodes))
gc.collect()
torch.cuda.empty_cache()
Expand Down Expand Up @@ -102,7 +102,7 @@ def retrieve_seed_edges(self, query: Any, k_edges: int = 3) -> InputEdges:
Returns:
- The indices of the most similar edges.
"""
result, query_enc = next(
result, query_enc, _ = next(
self._retrieve_seed_edges_batch([query], k_edges))
gc.collect()
torch.cuda.empty_cache()
Expand Down
Loading