From 829fee50a9f132c327e7f38f890dcaaacfe29e92 Mon Sep 17 00:00:00 2001 From: Shenyang Huang Date: Thu, 8 Aug 2024 22:06:10 +0000 Subject: [PATCH 01/26] adding initial implementation of rhs transformer to repo --- .gitignore | 1 + README.md | 3 + .../relbench_link_prediction_benchmark.py | 22 ++- hybridgnn/nn/models/__init__.py | 6 +- hybridgnn/nn/models/hybrid_rhstransformer.py | 153 ++++++++++++++++++ hybridgnn/nn/models/transformer.py | 113 +++++++++++++ 6 files changed, 292 insertions(+), 6 deletions(-) create mode 100644 hybridgnn/nn/models/hybrid_rhstransformer.py create mode 100644 hybridgnn/nn/models/transformer.py diff --git a/.gitignore b/.gitignore index ced7769..bbc0e94 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ coverage.xml venv/* *.out data/** +*.txt diff --git a/README.md b/README.md index e59c177..f0354b1 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py) ```sh +python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purcahse --model rhstransformer python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn ``` @@ -31,4 +32,6 @@ pip install -e . # to run examples and benchmarks pip install -e '.[full]' + +pip install -U sentence-transformers ``` diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 4a857fa..3a0f6ee 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -30,7 +30,7 @@ from torch_geometric.utils.cross_entropy import sparse_cross_entropy from tqdm import tqdm -from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN +from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer from hybridgnn.utils import GloveTextEmbedding TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] @@ -43,7 +43,7 @@ "--model", type=str, default="hybridgnn", - choices=["hybridgnn", "idgnn", "shallowrhsgnn"], + choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer"], ) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--num_trials", type=int, default=10, @@ -102,7 +102,7 @@ int(args.num_neighbors // 2**i) for i in range(args.num_layers) ] -model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]] +model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer]] if args.model == "idgnn": model_search_space = { @@ -127,6 +127,18 @@ "gamma_rate": [0.9, 0.95, 1.], } model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) +elif args.model in ["rhstransformer"]: + model_search_space = { + "channels": [64, 128, 256], + "embedding_dim": [64, 128, 256], + "norm": ["layer_norm", "batch_norm"] + } + train_search_space = { + "batch_size": [256, 512, 1024], + "base_lr": [0.001, 0.01], + "gamma_rate": [0.9, 0.95, 1.], + } + model_cls = Hybrid_RHSTransformer def train( @@ -164,7 +176,7 @@ def train( loss = F.binary_cross_entropy_with_logits(out, target) numel = out.numel() - elif args.model in ["hybridgnn", "shallowrhsgnn"]: + elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: logits = model(batch, task.src_entity_table, task.dst_entity_table) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) @@ -248,7 +260,7 @@ def train_and_eval_with_cfg( persistent_workers=args.num_workers > 0, ) - if args.model in ["hybridgnn", "shallowrhsgnn"]: + if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: model_cfg["num_nodes"] = num_dst_nodes_dict["train"] elif args.model == "idgnn": model_cfg["out_channels"] = 1 diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py index c8e9ef2..f20d482 100644 --- a/hybridgnn/nn/models/__init__.py +++ b/hybridgnn/nn/models/__init__.py @@ -2,5 +2,9 @@ from .idgnn import IDGNN from .hybridgnn import HybridGNN from .shallowrhsgnn import ShallowRHSGNN +from .hybrid_rhstransformer import Hybrid_RHSTransformer -__all__ = classes = ['HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN'] +__all__ = classes = [ + 'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN', + 'Hybrid_RHSTransformer' +] diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py new file mode 100644 index 0000000..56ae32e --- /dev/null +++ b/hybridgnn/nn/models/hybrid_rhstransformer.py @@ -0,0 +1,153 @@ +from typing import Any, Dict + +import torch +from torch import Tensor +from torch_frame.data.stats import StatType +from torch_geometric.data import HeteroData +from torch_geometric.nn import MLP +from torch_geometric.typing import NodeType + +from hybridgnn.nn.encoder import ( + DEFAULT_STYPE_ENCODER_DICT, + HeteroEncoder, + HeteroTemporalEncoder, +) +from hybridgnn.nn.models import HeteroGraphSAGE +from hybridgnn.nn.models.transformer import RHSTransformer + + +class Hybrid_RHSTransformer(torch.nn.Module): + r"""Implementation of RHSTransformer model.""" + def __init__( + self, + data: HeteroData, + col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]], + num_nodes: int, + num_layers: int, + channels: int, + embedding_dim: int, + aggr: str = 'sum', + norm: str = 'layer_norm', + pe: str = "abs", + ) -> None: + super().__init__() + + self.encoder = HeteroEncoder( + channels=channels, + node_to_col_names_dict={ + node_type: data[node_type].tf.col_names_dict + for node_type in data.node_types + }, + node_to_col_stats=col_stats_dict, + stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, + ) + self.temporal_encoder = HeteroTemporalEncoder( + node_types=[ + node_type for node_type in data.node_types + if "time" in data[node_type] + ], + channels=channels, + ) + self.gnn = HeteroGraphSAGE( + node_types=data.node_types, + edge_types=data.edge_types, + channels=channels, + aggr=aggr, + num_layers=num_layers, + ) + self.head = MLP( + channels, + out_channels=1, + norm=norm, + num_layers=1, + ) + self.lhs_projector = torch.nn.Linear(channels, embedding_dim) + + self.id_awareness_emb = torch.nn.Embedding(1, channels) + self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim) + self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1) + self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1) + self.rhs_transformer = RHSTransformer(in_channels=channels, + out_channels=channels, + hidden_channels=channels, + heads=1, dropout=0.2, + position_encoding=pe) + + self.channels = channels + + self.reset_parameters() + + def reset_parameters(self) -> None: + self.encoder.reset_parameters() + self.temporal_encoder.reset_parameters() + self.gnn.reset_parameters() + self.head.reset_parameters() + self.id_awareness_emb.reset_parameters() + self.rhs_embedding.reset_parameters() + self.lin_offset_embgnn.reset_parameters() + self.lin_offset_idgnn.reset_parameters() + self.lhs_projector.reset_parameters() + self.rhs_transformer.reset_parameters() + + def forward( + self, + batch: HeteroData, + entity_table: NodeType, + dst_table: NodeType, + ) -> Tensor: + seed_time = batch[entity_table].seed_time + x_dict = self.encoder(batch.tf_dict) + + # Add ID-awareness to the root node + x_dict[entity_table][:seed_time.size(0 + )] += self.id_awareness_emb.weight + rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict, + batch.batch_dict) + + for node_type, rel_time in rel_time_dict.items(): + x_dict[node_type] = x_dict[node_type] + rel_time + + x_dict = self.gnn( + x_dict, + batch.edge_index_dict, + ) + + batch_size = seed_time.size(0) + lhs_embedding = x_dict[entity_table][: + batch_size] # batch_size, channel + lhs_embedding_projected = self.lhs_projector(lhs_embedding) + rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel + rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs + lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size + + #! adding transformer here + rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding, + lhs_idgnn_batch) + + rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel + embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t( + ) # batch_size, num_rhs_nodes + + # Model the importance of embedding-GNN prediction for each lhs node + embgnn_offset_logits = self.lin_offset_embgnn( + lhs_embedding_projected).flatten() + embgnn_logits += embgnn_offset_logits.view(-1, 1) + + # Calculate idgnn logits + idgnn_logits = self.head( + rhs_gnn_embedding).flatten() # num_sampled_rhs + # Because we are only doing 2 hop, we are not really sampling info from + # lhs therefore, we need to incorporate this information using + # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding + idgnn_logits += ( + lhs_embedding[lhs_idgnn_batch] * # num_sampled_rhs, channel + rhs_gnn_embedding).sum( + dim=-1).flatten() # num_sampled_rhs, channel + + # Model the importance of ID-GNN prediction for each lhs node + idgnn_offset_logits = self.lin_offset_idgnn( + lhs_embedding_projected).flatten() + idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch] + + embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits + return embgnn_logits diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py new file mode 100644 index 0000000..dcab570 --- /dev/null +++ b/hybridgnn/nn/models/transformer.py @@ -0,0 +1,113 @@ +import torch +from torch import Tensor, nn +from torch_geometric.typing import EdgeType, NodeType +from torch.nested import nested_tensor + +from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock +from torch_geometric.utils import to_dense_batch, to_nested_tensor, from_nested_tensor +from torch_geometric.utils import cumsum, scatter +from torch_geometric.nn.encoding import PositionalEncoding + + +class RHSTransformer(torch.nn.Module): + r"""A module to attend to rhs embeddings with a transformer. + Args: + in_channels (int): The number of input channels of the RHS embedding. + out_channels (int): The number of output channels. + hidden_channels (int): The hidden channel dimension of the transformer. + heads (int): The number of attention heads for the transformer. + num_transformer_blocks (int): The number of transformer blocks. + dropout (float): dropout rate for the transformer + """ + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int = 64, + heads: int = 1, + num_transformer_blocks: int = 1, + dropout: float = 0.0, + position_encoding: str = "abs", + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.lin = torch.nn.Linear(in_channels, hidden_channels) + self.fc = torch.nn.Linear(hidden_channels, out_channels) + self.pe_type = position_encoding + self.pe = None + if (position_encoding == "abs"): + self.pe = PositionalEncoding(hidden_channels) + elif (position_encoding == "rope"): + # rotary pe for queries + self.q_pe = RotaryPositionalEmbeddings(hidden_channels) + # rotary pe for keys + self.k_pe = RotaryPositionalEmbeddings(hidden_channels) + + self.blocks = torch.nn.ModuleList([ + MultiheadAttentionBlock( + channels=hidden_channels, + heads=heads, + layer_norm=True, + dropout=dropout, + ) for _ in range(num_transformer_blocks) + ]) + + def reset_parameters(self): + for block in self.blocks: + block.reset_parameters() + self.lin.reset_parameters() + self.fc.reset_parameters() + + def forward(self, rhs_embed: Tensor, index: Tensor, + rhs_time: Tensor = None) -> Tensor: + r"""Returns the attended to rhs embeddings + """ + rhs_embed = self.lin(rhs_embed) + + if (self.pe_type == "abs"): + if (rhs_time is None): + rhs_embed = rhs_embed + self.pe( + torch.arange(rhs_embed.size(0), device=rhs_embed.device)) + else: + rhs_embed = rhs_embed + self.pe(rhs_time) + + x, mask = to_dense_batch(rhs_embed, index) + for block in self.blocks: + # apply the pe for both query and key + if (self.pe_type == "rope"): + x_q = self.q_pe(x, pos=rhs_time) + x_k = self.k_pe(x, pos=rhs_time) + else: + x_q = x + x_k = x + x = block(x_q, x_k) + x = x[mask] + x = x.view(-1, self.hidden_channels) + return self.fc(x) + + +class RotaryPositionalEmbeddings(torch.nn.Module): + def __init__(self, channels, base=10000): + super().__init__() + self.channels = channels + self.base = base + self.inv_freq = 1. / (base**(torch.arange(0, channels, 2).float() / + channels)) + + def forward(self, x, pos=None): + seq_len = x.shape[1] + if (pos is None): + pos = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) + freqs = torch.einsum('i,j->ij', pos, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + + cos = emb.cos().to(x.device) + sin = emb.sin().to(x.device) + + x1, x2 = x[..., ::2], x[..., 1::2] + rotated = torch.stack([-x2, x1], dim=-1).reshape(x.shape).to(x.device) + + return x * cos + rotated * sin From 12a83bd062432718765b3113234edcf9cd1feeff Mon Sep 17 00:00:00 2001 From: Shenyang Huang Date: Fri, 9 Aug 2024 18:54:43 +0000 Subject: [PATCH 02/26] adding rhs transformer to the benchmark script --- .../relbench_link_prediction_benchmark.py | 14 +++++++----- hybridgnn/nn/models/hybrid_rhstransformer.py | 22 +++++++++++++++---- hybridgnn/nn/models/transformer.py | 19 +++++----------- 3 files changed, 32 insertions(+), 23 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 3a0f6ee..470bfb7 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -129,14 +129,16 @@ model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) elif args.model in ["rhstransformer"]: model_search_space = { - "channels": [64, 128, 256], - "embedding_dim": [64, 128, 256], - "norm": ["layer_norm", "batch_norm"] + "channels": [64, 128], + "embedding_dim": [64, 128], + "norm": ["layer_norm"], + "dropout": [0.1, 0.2], + "pe": ["abs", "none"], } train_search_space = { "batch_size": [256, 512, 1024], - "base_lr": [0.001, 0.01], - "gamma_rate": [0.9, 0.95, 1.], + "base_lr": [0.001, 0.01, 0.0001], + "gamma_rate": [0.9, 1.0], } model_cls = Hybrid_RHSTransformer @@ -217,7 +219,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: device=out.device) scores[batch[task.dst_entity_table].batch, batch[task.dst_entity_table].n_id] = torch.sigmoid(out) - elif args.model in ["hybridgnn", "shallowrhsgnn"]: + elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: # Get ground-truth out = model(batch, task.src_entity_table, task.dst_entity_table).detach() diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py index 56ae32e..f7c3d2a 100644 --- a/hybridgnn/nn/models/hybrid_rhstransformer.py +++ b/hybridgnn/nn/models/hybrid_rhstransformer.py @@ -17,7 +17,19 @@ class Hybrid_RHSTransformer(torch.nn.Module): - r"""Implementation of RHSTransformer model.""" + r"""Implementation of RHSTransformer model. + Args: + data (HeteroData): dataset + col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): column stats + num_nodes (int): number of nodes, + num_layers (int): number of mp layers, + channels (int): input dimension, + embedding_dim (int): embedding dimension size, + aggr (str): aggregation type, + norm (norm): normalization type, + dropout (float): dropout rate for the transformer float, + heads (int): number of attention heads, + pe (str): type of positional encoding for the transformer,""" def __init__( self, data: HeteroData, @@ -28,6 +40,8 @@ def __init__( embedding_dim: int, aggr: str = 'sum', norm: str = 'layer_norm', + dropout: float = 0.2, + heads: int = 1, pe: str = "abs", ) -> None: super().__init__() @@ -62,15 +76,15 @@ def __init__( num_layers=1, ) self.lhs_projector = torch.nn.Linear(channels, embedding_dim) - self.id_awareness_emb = torch.nn.Embedding(1, channels) self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim) self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1) self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1) + self.rhs_transformer = RHSTransformer(in_channels=channels, out_channels=channels, hidden_channels=channels, - heads=1, dropout=0.2, + heads=heads, dropout=dropout, position_encoding=pe) self.channels = channels @@ -120,7 +134,7 @@ def forward( rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size - #! adding transformer here + # adding rhs transformer rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding, lhs_idgnn_batch) diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index dcab570..f8290db 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -18,6 +18,7 @@ class RHSTransformer(torch.nn.Module): heads (int): The number of attention heads for the transformer. num_transformer_blocks (int): The number of transformer blocks. dropout (float): dropout rate for the transformer + position_encoding (str): type of positional encoding, """ def __init__( self, @@ -40,11 +41,10 @@ def __init__( self.pe = None if (position_encoding == "abs"): self.pe = PositionalEncoding(hidden_channels) - elif (position_encoding == "rope"): - # rotary pe for queries - self.q_pe = RotaryPositionalEmbeddings(hidden_channels) - # rotary pe for keys - self.k_pe = RotaryPositionalEmbeddings(hidden_channels) + elif (position_encoding == "none"): + self.pe = None + else: + raise NotImplementedError self.blocks = torch.nn.ModuleList([ MultiheadAttentionBlock( @@ -76,14 +76,7 @@ def forward(self, rhs_embed: Tensor, index: Tensor, x, mask = to_dense_batch(rhs_embed, index) for block in self.blocks: - # apply the pe for both query and key - if (self.pe_type == "rope"): - x_q = self.q_pe(x, pos=rhs_time) - x_k = self.k_pe(x, pos=rhs_time) - else: - x_q = x - x_k = x - x = block(x_q, x_k) + x = block(x, x) x = x[mask] x = x.view(-1, self.hidden_channels) return self.fc(x) From a3a9779817326292497edbd2dc740500df326f95 Mon Sep 17 00:00:00 2001 From: shenyang Date: Mon, 12 Aug 2024 18:55:04 +0000 Subject: [PATCH 03/26] rhs transformer upload --- README.md | 6 +- .../relbench_link_prediction_benchmark.py | 22 +++++-- hybridgnn/nn/models/hybrid_rhstransformer.py | 65 ++++++++++++++++++- hybridgnn/nn/models/transformer.py | 15 ++--- 4 files changed, 89 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index f0354b1..622e1a1 100644 --- a/README.md +++ b/README.md @@ -13,15 +13,15 @@ Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo- ```sh python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purcahse --model rhstransformer -python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn +python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4 ``` Run [`examples/relbench_example.py`](https://github.com/kumo-ai/hybridgnn/blob/master/examples/relbench_example.py) ```sh -python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn -python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn +python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4 +python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn --num_layers 4 ``` diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 470bfb7..b326f3b 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -117,12 +117,12 @@ model_cls = IDGNN elif args.model in ["hybridgnn", "shallowrhsgnn"]: model_search_space = { - "channels": [64, 128, 256], - "embedding_dim": [64, 128, 256], + "channels": [64, 128], + "embedding_dim": [64, 128], "norm": ["layer_norm", "batch_norm"] } train_search_space = { - "batch_size": [256, 512, 1024], + "batch_size": [256], "base_lr": [0.001, 0.01], "gamma_rate": [0.9, 0.95, 1.], } @@ -136,7 +136,7 @@ "pe": ["abs", "none"], } train_search_space = { - "batch_size": [256, 512, 1024], + "batch_size": [128], "base_lr": [0.001, 0.01, 0.0001], "gamma_rate": [0.9, 1.0], } @@ -178,11 +178,16 @@ def train( loss = F.binary_cross_entropy_with_logits(out, target) numel = out.numel() - elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: + elif args.model in ["hybridgnn", "shallowrhsgnn"]: logits = model(batch, task.src_entity_table, task.dst_entity_table) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) + elif args.model in ["rhstransformer"]: + logits = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) + edge_label_index = torch.stack([src_batch, dst_index], dim=0) + loss = sparse_cross_entropy(logits, edge_label_index) + numel = len(batch[task.dst_entity_table].batch) loss.backward() optimizer.step() @@ -219,11 +224,16 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: device=out.device) scores[batch[task.dst_entity_table].batch, batch[task.dst_entity_table].n_id] = torch.sigmoid(out) - elif args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: + elif args.model in ["hybridgnn", "shallowrhsgnn"]: # Get ground-truth out = model(batch, task.src_entity_table, task.dst_entity_table).detach() scores = torch.sigmoid(out) + elif args.model in ["rhstransformer"]: + out = model(batch, task.src_entity_table, + task.dst_entity_table, + task.dst_entity_col).detach() + scores = torch.sigmoid(out) else: raise ValueError(f"Unsupported model type: {args.model}.") diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py index f7c3d2a..a16de0a 100644 --- a/hybridgnn/nn/models/hybrid_rhstransformer.py +++ b/hybridgnn/nn/models/hybrid_rhstransformer.py @@ -14,6 +14,7 @@ ) from hybridgnn.nn.models import HeteroGraphSAGE from hybridgnn.nn.models.transformer import RHSTransformer +from torch_scatter import scatter_max class Hybrid_RHSTransformer(torch.nn.Module): @@ -108,7 +109,12 @@ def forward( batch: HeteroData, entity_table: NodeType, dst_table: NodeType, + dst_entity_col: NodeType, ) -> Tensor: + # print ("time dict has the following keys") + # print (batch.time_dict.keys()) + # dict_keys(['drop_withdrawals', 'outcomes', 'outcome_analyses', 'eligibilities', 'sponsors_studies', 'facilities_studies', 'interventions_studies', 'studies', 'designs', 'reported_event_totals', 'conditions_studies']) + seed_time = batch[entity_table].seed_time x_dict = self.encoder(batch.tf_dict) @@ -134,9 +140,12 @@ def forward( rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size + #! need custom code to work for specific datasets + # rhs_time = self.get_rhs_time_dict(batch.time_dict, batch.edge_index_dict, batch[entity_table].seed_time, batch, dst_entity_col, dst_table) + # adding rhs transformer rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding, - lhs_idgnn_batch) + lhs_idgnn_batch, batch_size=batch_size) rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t( @@ -165,3 +174,57 @@ def forward( embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits return embgnn_logits + + def get_rhs_time_dict( + self, + time_dict, + edge_index_dict, + seed_time, + batch_dict, + dst_entity_col, + dst_entity_table, + ): + # edge_index_dict keys + """ + dict_keys([('drop_withdrawals', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'drop_withdrawals'), + ('outcomes', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'outcomes'), + ('outcome_analyses', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'outcome_analyses'), + ('outcome_analyses', 'f2p_outcome_id', 'outcomes'), + ('outcomes', 'rev_f2p_outcome_id', 'outcome_analyses'), + ('eligibilities', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'eligibilities'), + ('sponsors_studies', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'sponsors_studies'), + ('sponsors_studies', 'f2p_sponsor_id', 'sponsors'), + ('sponsors', 'rev_f2p_sponsor_id', 'sponsors_studies'), + ('facilities_studies', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'facilities_studies'), + ('facilities_studies', 'f2p_facility_id', 'facilities'), + ('facilities', 'rev_f2p_facility_id', 'facilities_studies'), + ('interventions_studies', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'interventions_studies'), + ('interventions_studies', 'f2p_intervention_id', 'interventions'), + ('interventions', 'rev_f2p_intervention_id', 'interventions_studies'), + ('designs', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'designs'), + ('reported_event_totals', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'reported_event_totals'), + ('conditions_studies', 'f2p_nct_id', 'studies'), + ('studies', 'rev_f2p_nct_id', 'conditions_studies'), + ('conditions_studies', 'f2p_condition_id', 'conditions'), + ('conditions', 'rev_f2p_condition_id', 'conditions_studies')]) + """ + #* what to put when transaction table is merged + edge_index = edge_index_dict['sponsors','f2p_sponsor_id', + 'sponsors_studies'] + rhs_time, _ = scatter_max( + time_dict['sponsors'][edge_index[0]], + edge_index[1]) + SECONDS_IN_A_DAY = 60 * 60 * 24 + NANOSECONDS_IN_A_DAY = 60 * 60 * 24 * 1000000000 + rhs_rel_time = seed_time[batch_dict[dst_entity_col]] - rhs_time + rhs_rel_time = rhs_rel_time / NANOSECONDS_IN_A_DAY + return rhs_rel_time diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index f8290db..b189c2f 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -1,4 +1,5 @@ import torch +import math from torch import Tensor, nn from torch_geometric.typing import EdgeType, NodeType from torch.nested import nested_tensor @@ -61,20 +62,16 @@ def reset_parameters(self): self.lin.reset_parameters() self.fc.reset_parameters() - def forward(self, rhs_embed: Tensor, index: Tensor, - rhs_time: Tensor = None) -> Tensor: + + def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: r"""Returns the attended to rhs embeddings """ rhs_embed = self.lin(rhs_embed) if (self.pe_type == "abs"): - if (rhs_time is None): - rhs_embed = rhs_embed + self.pe( - torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - else: - rhs_embed = rhs_embed + self.pe(rhs_time) - - x, mask = to_dense_batch(rhs_embed, index) + rhs_embed = rhs_embed + self.pe( + torch.arange(rhs_embed.size(0), device=rhs_embed.device)) + x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) for block in self.blocks: x = block(x, x) x = x[mask] From 988781ba6c6d6e88a89dae59f6f38e8c78767008 Mon Sep 17 00:00:00 2001 From: shenyang Date: Mon, 12 Aug 2024 19:20:11 +0000 Subject: [PATCH 04/26] updating tr --- hybridgnn/nn/models/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index b189c2f..5f88da6 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -71,7 +71,7 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: if (self.pe_type == "abs"): rhs_embed = rhs_embed + self.pe( torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) + x, mask = to_dense_batch(rhs_embed, index, max_num_nodes=batch_size) for block in self.blocks: x = block(x, x) x = x[mask] From e5faa22afa74bce3c0e6697ec060a73e44ed8b5a Mon Sep 17 00:00:00 2001 From: shenyang Date: Mon, 12 Aug 2024 21:08:33 +0000 Subject: [PATCH 05/26] running code --- benchmark/relbench_link_prediction_benchmark.py | 7 ++++--- hybridgnn/nn/models/transformer.py | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index b326f3b..984079d 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -129,14 +129,15 @@ model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) elif args.model in ["rhstransformer"]: model_search_space = { - "channels": [64, 128], - "embedding_dim": [64, 128], + "channels": [64], + "embedding_dim": [64], "norm": ["layer_norm"], "dropout": [0.1, 0.2], "pe": ["abs", "none"], + "num_neighbors": [args.num_neighbors], } train_search_space = { - "batch_size": [128], + "batch_size": [64], "base_lr": [0.001, 0.01, 0.0001], "gamma_rate": [0.9, 1.0], } diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index 5f88da6..40a3e8c 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -71,7 +71,11 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: if (self.pe_type == "abs"): rhs_embed = rhs_embed + self.pe( torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - x, mask = to_dense_batch(rhs_embed, index, max_num_nodes=batch_size) + + sorted_index, _ = torch.sort(index) + index = sorted_index + + x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) for block in self.blocks: x = block(x, x) x = x[mask] From 5e904eae13c5bcbc9411106f57835671aacc3cb0 Mon Sep 17 00:00:00 2001 From: shenyang Date: Mon, 12 Aug 2024 21:09:22 +0000 Subject: [PATCH 06/26] running code --- benchmark/relbench_link_prediction_benchmark.py | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 984079d..ddb626d 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -134,7 +134,6 @@ "norm": ["layer_norm"], "dropout": [0.1, 0.2], "pe": ["abs", "none"], - "num_neighbors": [args.num_neighbors], } train_search_space = { "batch_size": [64], From d6ff1698e29d26b298723a02640c989d5777b65a Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Tue, 13 Aug 2024 20:47:16 +0000 Subject: [PATCH 07/26] adding transformer changes --- README.md | 1 + hybridgnn/nn/models/transformer.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 622e1a1..d02d749 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py) ```sh +python relbench_link_prediction_benchmark.py --dataset rel-stack --task user-post-comment --model rhstransformer --num_trials 10 python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purcahse --model rhstransformer python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4 ``` diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index 40a3e8c..502ee4e 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -72,8 +72,9 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: rhs_embed = rhs_embed + self.pe( torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - sorted_index, _ = torch.sort(index) - index = sorted_index + # #! if we sort the index, we need to sort the rhs_embed + # sorted_index, _ = torch.sort(index) + # assert torch.equal(index, sorted_index) x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) for block in self.blocks: From 4884b837b081c0f0fd363a93c6528851949511fc Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Tue, 13 Aug 2024 22:19:00 +0000 Subject: [PATCH 08/26] permute the index, the rhs and then reverse it --- hybridgnn/nn/models/transformer.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index 502ee4e..d227d8b 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -63,7 +63,7 @@ def reset_parameters(self): self.fc.reset_parameters() - def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: + def forward(self, rhs_embed: Tensor, index: Tensor, batch_size) -> Tensor: r"""Returns the attended to rhs embeddings """ rhs_embed = self.lin(rhs_embed) @@ -73,7 +73,10 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: torch.arange(rhs_embed.size(0), device=rhs_embed.device)) # #! if we sort the index, we need to sort the rhs_embed - # sorted_index, _ = torch.sort(index) + sorted_index, sorted_idx = torch.sort(index, stable=True) + index = index[sorted_idx] + rhs_embed = rhs_embed[sorted_idx] + reverse = self.inverse_permutation(sorted_idx) # assert torch.equal(index, sorted_index) x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) @@ -81,8 +84,16 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size=512) -> Tensor: x = block(x, x) x = x[mask] x = x.view(-1, self.hidden_channels) + x = x[reverse] + # x = x.gather(1, sorted_idx.argsort(1)) + return self.fc(x) + def inverse_permutation(self,perm): + inv = torch.empty_like(perm) + inv[perm] = torch.arange(perm.size(0), device=perm.device) + return inv + class RotaryPositionalEmbeddings(torch.nn.Module): def __init__(self, channels, base=10000): From 08041cb73905f888ac7a2c58beabc272cc31b6c2 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Tue, 13 Aug 2024 22:21:15 +0000 Subject: [PATCH 09/26] removing none to replace with None --- .../relbench_link_prediction_benchmark.py | 2 +- hybridgnn/nn/models/hybrid_rhstransformer.py | 33 ------------------- hybridgnn/nn/models/transformer.py | 2 +- 3 files changed, 2 insertions(+), 35 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index ddb626d..2586e86 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -133,7 +133,7 @@ "embedding_dim": [64], "norm": ["layer_norm"], "dropout": [0.1, 0.2], - "pe": ["abs", "none"], + "pe": ["abs", None], } train_search_space = { "batch_size": [64], diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py index a16de0a..a22202d 100644 --- a/hybridgnn/nn/models/hybrid_rhstransformer.py +++ b/hybridgnn/nn/models/hybrid_rhstransformer.py @@ -184,39 +184,6 @@ def get_rhs_time_dict( dst_entity_col, dst_entity_table, ): - # edge_index_dict keys - """ - dict_keys([('drop_withdrawals', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'drop_withdrawals'), - ('outcomes', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'outcomes'), - ('outcome_analyses', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'outcome_analyses'), - ('outcome_analyses', 'f2p_outcome_id', 'outcomes'), - ('outcomes', 'rev_f2p_outcome_id', 'outcome_analyses'), - ('eligibilities', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'eligibilities'), - ('sponsors_studies', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'sponsors_studies'), - ('sponsors_studies', 'f2p_sponsor_id', 'sponsors'), - ('sponsors', 'rev_f2p_sponsor_id', 'sponsors_studies'), - ('facilities_studies', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'facilities_studies'), - ('facilities_studies', 'f2p_facility_id', 'facilities'), - ('facilities', 'rev_f2p_facility_id', 'facilities_studies'), - ('interventions_studies', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'interventions_studies'), - ('interventions_studies', 'f2p_intervention_id', 'interventions'), - ('interventions', 'rev_f2p_intervention_id', 'interventions_studies'), - ('designs', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'designs'), - ('reported_event_totals', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'reported_event_totals'), - ('conditions_studies', 'f2p_nct_id', 'studies'), - ('studies', 'rev_f2p_nct_id', 'conditions_studies'), - ('conditions_studies', 'f2p_condition_id', 'conditions'), - ('conditions', 'rev_f2p_condition_id', 'conditions_studies')]) - """ #* what to put when transaction table is merged edge_index = edge_index_dict['sponsors','f2p_sponsor_id', 'sponsors_studies'] diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index d227d8b..f7e1f14 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -42,7 +42,7 @@ def __init__( self.pe = None if (position_encoding == "abs"): self.pe = PositionalEncoding(hidden_channels) - elif (position_encoding == "none"): + elif (position_encoding is None): self.pe = None else: raise NotImplementedError From 29c70ad25c799ddf239f65314846a8f2484d6039 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Wed, 14 Aug 2024 19:41:10 +0000 Subject: [PATCH 10/26] add time fuse encoder to extract time pe --- .../relbench_link_prediction_benchmark.py | 6 +++--- hybridgnn/nn/encoder.py | 20 +++++++++++++++++-- hybridgnn/nn/models/transformer.py | 3 --- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 2586e86..080ccf5 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -133,11 +133,11 @@ "embedding_dim": [64], "norm": ["layer_norm"], "dropout": [0.1, 0.2], - "pe": ["abs", None], + "pe": [None], } train_search_space = { - "batch_size": [64], - "base_lr": [0.001, 0.01, 0.0001], + "batch_size": [512], + "base_lr": [0.0005, 0.01], "gamma_rate": [0.9, 1.0], } model_cls = Hybrid_RHSTransformer diff --git a/hybridgnn/nn/encoder.py b/hybridgnn/nn/encoder.py index 022b30a..c56ad00 100644 --- a/hybridgnn/nn/encoder.py +++ b/hybridgnn/nn/encoder.py @@ -19,6 +19,9 @@ torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}), } SECONDS_IN_A_DAY = 60 * 60 * 24 +SECONDS_IN_A_WEEK = 7 * 60 * 60 * 24 +SECONDS_IN_A_HOUR = 60 * 60 +SECONDS_IN_A_MINUTE = 60 class HeteroEncoder(torch.nn.Module): @@ -109,11 +112,15 @@ def __init__(self, node_types: List[NodeType], channels: int) -> None: for node_type in node_types }) + time_dim = 3 # hour, day, week + self.time_fuser = torch.nn.Linear(time_dim, channels) + def reset_parameters(self) -> None: for encoder in self.encoder_dict.values(): encoder.reset_parameters() for lin in self.lin_dict.values(): lin.reset_parameters() + self.time_fuser.reset_parameters() def forward( self, @@ -125,9 +132,18 @@ def forward( for node_type, time in time_dict.items(): rel_time = seed_time[batch_dict[node_type]] - time - rel_time = rel_time / SECONDS_IN_A_DAY - x = self.encoder_dict[node_type](rel_time) + # rel_day = rel_time / SECONDS_IN_A_DAY + # x = self.encoder_dict[node_type](rel_day) + # x = self.encoder_dict[node_type](rel_hour) + rel_hour = (rel_time // SECONDS_IN_A_HOUR).view(-1,1) + rel_day = (rel_time // SECONDS_IN_A_DAY).view(-1,1) + rel_week = (rel_time // SECONDS_IN_A_WEEK).view(-1,1) + time_embed = torch.cat((rel_hour, rel_day, rel_week),dim=1).float() + + #! might need to normalize hour, day, week into the same scale + time_embed = torch.nn.functional.normalize(time_embed, p=2.0, dim=1) + x = self.time_fuser(time_embed) x = self.lin_dict[node_type](x) out_dict[node_type] = x diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index f7e1f14..a9cf2ac 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -77,7 +77,6 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size) -> Tensor: index = index[sorted_idx] rhs_embed = rhs_embed[sorted_idx] reverse = self.inverse_permutation(sorted_idx) - # assert torch.equal(index, sorted_index) x, mask = to_dense_batch(rhs_embed, index, batch_size=batch_size) for block in self.blocks: @@ -85,8 +84,6 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size) -> Tensor: x = x[mask] x = x.view(-1, self.hidden_channels) x = x[reverse] - # x = x.gather(1, sorted_idx.argsort(1)) - return self.fc(x) def inverse_permutation(self,perm): From f58b52844b5bdabb5b1b995cc1e5beb1d13c9828 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Wed, 14 Aug 2024 19:42:48 +0000 Subject: [PATCH 11/26] update hyperparameter options --- benchmark/relbench_link_prediction_benchmark.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 080ccf5..5331f7c 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -129,14 +129,14 @@ model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) elif args.model in ["rhstransformer"]: model_search_space = { - "channels": [64], - "embedding_dim": [64], + "channels": [64, 128], + "embedding_dim": [64, 128], "norm": ["layer_norm"], "dropout": [0.1, 0.2], "pe": [None], } train_search_space = { - "batch_size": [512], + "batch_size": [128, 256, 512], "base_lr": [0.0005, 0.01], "gamma_rate": [0.9, 1.0], } From 8f4966cec374bb76797edfc47160d1e0036d2a77 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Mon, 19 Aug 2024 22:54:33 +0000 Subject: [PATCH 12/26] adding rerank_transformer --- .../relbench_link_prediction_benchmark.py | 35 ++- hybridgnn/nn/models/__init__.py | 3 +- hybridgnn/nn/models/hybrid_rhstransformer.py | 4 - hybridgnn/nn/models/rerank_transformer.py | 216 ++++++++++++++++++ 4 files changed, 248 insertions(+), 10 deletions(-) create mode 100644 hybridgnn/nn/models/rerank_transformer.py diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 5331f7c..715c77f 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -30,7 +30,7 @@ from torch_geometric.utils.cross_entropy import sparse_cross_entropy from tqdm import tqdm -from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer +from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer, ReRankTransformer from hybridgnn.utils import GloveTextEmbedding TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] @@ -43,7 +43,7 @@ "--model", type=str, default="hybridgnn", - choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer"], + choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer", "rerank_transformer"], ) parser.add_argument("--epochs", type=int, default=20) parser.add_argument("--num_trials", type=int, default=10, @@ -102,7 +102,7 @@ int(args.num_neighbors // 2**i) for i in range(args.num_layers) ] -model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer]] +model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer, ReRankTransformer]] if args.model == "idgnn": model_search_space = { @@ -141,7 +141,20 @@ "gamma_rate": [0.9, 1.0], } model_cls = Hybrid_RHSTransformer - +elif args.model in ["rerank_transformer"]: + model_search_space = { + "channels": [64], + "embedding_dim": [64], + "norm": ["layer_norm"], + "dropout": [0.0, 0.1, 0.2], + "rank_topk": [25,50,100] + } + train_search_space = { + "batch_size": [128, 256, 512], + "base_lr": [0.0005, 0.01], + "gamma_rate": [0.9, 1.0], + } + model_cls = ReRankTransformer def train( model: torch.nn.Module, @@ -188,6 +201,13 @@ def train( edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) + elif args.model in ["rerank_transformer"]: + gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) + edge_label_index = torch.stack([src_batch, dst_index], dim=0) + loss = sparse_cross_entropy(gnn_logits, edge_label_index) + loss += sparse_cross_entropy(tr_logits, edge_label_index) + numel = len(batch[task.dst_entity_table].batch) + loss.backward() optimizer.step() @@ -234,6 +254,11 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: task.dst_entity_table, task.dst_entity_col).detach() scores = torch.sigmoid(out) + elif args.model in ["rerank_transformer"]: + _, out, _ = model(batch, task.src_entity_table, + task.dst_entity_table, + task.dst_entity_col) + scores = torch.sigmoid(out.detach()) else: raise ValueError(f"Unsupported model type: {args.model}.") @@ -272,7 +297,7 @@ def train_and_eval_with_cfg( persistent_workers=args.num_workers > 0, ) - if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer"]: + if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer", "rerank_transformer"]: model_cfg["num_nodes"] = num_dst_nodes_dict["train"] elif args.model == "idgnn": model_cfg["out_channels"] = 1 diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py index f20d482..5d9f3cf 100644 --- a/hybridgnn/nn/models/__init__.py +++ b/hybridgnn/nn/models/__init__.py @@ -3,8 +3,9 @@ from .hybridgnn import HybridGNN from .shallowrhsgnn import ShallowRHSGNN from .hybrid_rhstransformer import Hybrid_RHSTransformer +from .rerank_transformer import ReRankTransformer __all__ = classes = [ 'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN', - 'Hybrid_RHSTransformer' + 'Hybrid_RHSTransformer', 'ReRankTransformer' ] diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/hybrid_rhstransformer.py index a22202d..e373ef4 100644 --- a/hybridgnn/nn/models/hybrid_rhstransformer.py +++ b/hybridgnn/nn/models/hybrid_rhstransformer.py @@ -111,10 +111,6 @@ def forward( dst_table: NodeType, dst_entity_col: NodeType, ) -> Tensor: - # print ("time dict has the following keys") - # print (batch.time_dict.keys()) - # dict_keys(['drop_withdrawals', 'outcomes', 'outcome_analyses', 'eligibilities', 'sponsors_studies', 'facilities_studies', 'interventions_studies', 'studies', 'designs', 'reported_event_totals', 'conditions_studies']) - seed_time = batch[entity_table].seed_time x_dict = self.encoder(batch.tf_dict) diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py new file mode 100644 index 0000000..e2760fc --- /dev/null +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -0,0 +1,216 @@ +from typing import Any, Dict + +import torch +from torch import Tensor +from torch_frame.data.stats import StatType +from torch_geometric.data import HeteroData +from torch_geometric.nn import MLP +from torch_geometric.typing import NodeType + +from hybridgnn.nn.encoder import ( + DEFAULT_STYPE_ENCODER_DICT, + HeteroEncoder, + HeteroTemporalEncoder, +) +from hybridgnn.nn.models import HeteroGraphSAGE +from torch_scatter import scatter_max +from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock +from torch_geometric.utils import to_dense_batch + + + +class ReRankTransformer(torch.nn.Module): + r"""Implementation of ReRank Transformer model. + Args: + data (HeteroData): dataset + col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): column stats + num_nodes (int): number of nodes, + num_layers (int): number of mp layers, + channels (int): input dimension, + embedding_dim (int): embedding dimension size, + aggr (str): aggregation type, + norm (norm): normalization type, + dropout (float): dropout rate for the transformer float, + heads (int): number of attention heads, + rank_topk (int): how many top results of gnn would be reranked,""" + def __init__( + self, + data: HeteroData, + col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]], + num_nodes: int, + num_layers: int, + channels: int, + embedding_dim: int, + aggr: str = 'sum', + norm: str = 'layer_norm', + dropout: float = 0.2, + heads: int = 1, + rank_topk: int = 100, + ) -> None: + super().__init__() + + self.encoder = HeteroEncoder( + channels=channels, + node_to_col_names_dict={ + node_type: data[node_type].tf.col_names_dict + for node_type in data.node_types + }, + node_to_col_stats=col_stats_dict, + stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, + ) + self.temporal_encoder = HeteroTemporalEncoder( + node_types=[ + node_type for node_type in data.node_types + if "time" in data[node_type] + ], + channels=channels, + ) + self.gnn = HeteroGraphSAGE( + node_types=data.node_types, + edge_types=data.edge_types, + channels=channels, + aggr=aggr, + num_layers=num_layers, + ) + self.head = MLP( + channels, + out_channels=1, + norm=norm, + num_layers=1, + ) + self.lhs_projector = torch.nn.Linear(channels, embedding_dim) + self.id_awareness_emb = torch.nn.Embedding(1, channels) + self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim) + self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1) + self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1) + + self.rank_topk = rank_topk + self.tr_blocks = torch.nn.ModuleList([ + MultiheadAttentionBlock( + channels=embedding_dim, + heads=heads, + layer_norm=True, + dropout=dropout, + ) for _ in range(1) + ]) + self.channels = channels + + self.reset_parameters() + + def reset_parameters(self) -> None: + self.encoder.reset_parameters() + self.temporal_encoder.reset_parameters() + self.gnn.reset_parameters() + self.head.reset_parameters() + self.id_awareness_emb.reset_parameters() + self.rhs_embedding.reset_parameters() + self.lin_offset_embgnn.reset_parameters() + self.lin_offset_idgnn.reset_parameters() + self.lhs_projector.reset_parameters() + for block in self.tr_blocks: + block.reset_parameters() + + def forward( + self, + batch: HeteroData, + entity_table: NodeType, + dst_table: NodeType, + dst_entity_col: NodeType, + ) -> Tensor: + + seed_time = batch[entity_table].seed_time + x_dict = self.encoder(batch.tf_dict) + + # Add ID-awareness to the root node + x_dict[entity_table][:seed_time.size(0 + )] += self.id_awareness_emb.weight + rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict, + batch.batch_dict) + + for node_type, rel_time in rel_time_dict.items(): + x_dict[node_type] = x_dict[node_type] + rel_time + + x_dict = self.gnn( + x_dict, + batch.edge_index_dict, + ) + + batch_size = seed_time.size(0) + lhs_embedding = x_dict[entity_table][: + batch_size] # batch_size, channel + lhs_embedding_projected = self.lhs_projector(lhs_embedding) + rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel + rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs + lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size + + rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel + embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t( + ) # batch_size, num_rhs_nodes + + # Model the importance of embedding-GNN prediction for each lhs node + embgnn_offset_logits = self.lin_offset_embgnn( + lhs_embedding_projected).flatten() + embgnn_logits += embgnn_offset_logits.view(-1, 1) + + # Calculate idgnn logits + idgnn_logits = self.head( + rhs_gnn_embedding).flatten() # num_sampled_rhs + # Because we are only doing 2 hop, we are not really sampling info from + # lhs therefore, we need to incorporate this information using + # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding + idgnn_logits += ( + lhs_embedding[lhs_idgnn_batch] * # num_sampled_rhs, channel + rhs_gnn_embedding).sum( + dim=-1).flatten() # num_sampled_rhs, channel + + # Model the importance of ID-GNN prediction for each lhs node + idgnn_offset_logits = self.lin_offset_idgnn( + lhs_embedding_projected).flatten() + idgnn_logits = idgnn_logits + idgnn_offset_logits[lhs_idgnn_batch] + + embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits + + + + """ + detach the variable + """ + all_rhs_embed = rhs_embedding.weight.detach().clone() #only shallow rhs embeds + assert all_rhs_embed.shape[1] == rhs_gnn_embedding.shape[1], "id GNN embed size should be the same as shallow RHS embed size" + all_rhs_embed[rhs_idgnn_index] = rhs_gnn_embedding.detach().clone() # apply the idGNN embeddings here + + + # all_rhs_embed = rhs_embedding.weight #only shallow rhs embeds + # #! this causes error when the channel size and hidden size is different + # assert all_rhs_embed.shape[1] == rhs_gnn_embedding.shape[1], "id GNN embed size should be the same as shallow RHS embed size" + # all_rhs_embed[rhs_idgnn_index] = rhs_gnn_embedding # apply the idGNN embeddings here + + + transformer_logits, topk_index = self.rerank(embgnn_logits.detach().clone(), all_rhs_embed, lhs_idgnn_batch.detach().clone(), lhs_embedding[lhs_idgnn_batch].detach().clone()) + # transformer_logits = self.rerank(embgnn_logits, all_rhs_embed, lhs_idgnn_batch, lhs_embedding[lhs_idgnn_batch]) + + + + # return embgnn_logits, transformer_logits + return embgnn_logits, transformer_logits, topk_index + + + def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): + """ + reranks the gnn logits based on the provided gnn embeddings. + rhs_gnn_embedding:[# rhs nodes, embed_dim] + """ + topk = self.rank_topk + _, topk_index = torch.topk(gnn_logits, self.rank_topk, dim=1) + embed_size = rhs_gnn_embedding.shape[1] + + # need input batch of size [# nodes, topk, embed_size] + top_embed = torch.stack([rhs_gnn_embedding[topk_index[idx]] for idx in range(topk_index.shape[0])]) + for block in self.tr_blocks: + tr_embed = block(top_embed, top_embed) # [# nodes, topk, embed_size] + + #! for top 50 prediction + # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])]) + for idx in range(topk_index.shape[0]): + gnn_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() + return gnn_logits, topk_index \ No newline at end of file From 90a8817de472aa773a3f2616ce137314daa386f5 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Mon, 19 Aug 2024 23:42:22 +0000 Subject: [PATCH 13/26] setting zeros for not used logits in rerank transformer --- benchmark/relbench_link_prediction_benchmark.py | 13 ++++++++++--- hybridgnn/nn/models/rerank_transformer.py | 8 +++----- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 715c77f..8bf4ab9 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -147,7 +147,7 @@ "embedding_dim": [64], "norm": ["layer_norm"], "dropout": [0.0, 0.1, 0.2], - "rank_topk": [25,50,100] + "rank_topk": [25,50,100, 200] } train_search_space = { "batch_size": [128, 256, 512], @@ -255,10 +255,17 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: task.dst_entity_col).detach() scores = torch.sigmoid(out) elif args.model in ["rerank_transformer"]: - _, out, _ = model(batch, task.src_entity_table, + gnn_logits, tr_logits, topk_index = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) - scores = torch.sigmoid(out.detach()) + gnn_logits, tr_logits, topk_index = gnn_logits.detach(), tr_logits.detach(), topk_index.detach() + for idx in range(topk_index.shape[0]): + gnn_logits[idx][topk_index[idx]] = tr_logits[idx][topk_index[idx]] + + scores = torch.sigmoid(gnn_logits) + #scores = torch.sigmoid(out.detach()) + + else: raise ValueError(f"Unsupported model type: {args.model}.") diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index e2760fc..a62d5bf 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -189,9 +189,6 @@ def forward( transformer_logits, topk_index = self.rerank(embgnn_logits.detach().clone(), all_rhs_embed, lhs_idgnn_batch.detach().clone(), lhs_embedding[lhs_idgnn_batch].detach().clone()) # transformer_logits = self.rerank(embgnn_logits, all_rhs_embed, lhs_idgnn_batch, lhs_embedding[lhs_idgnn_batch]) - - - # return embgnn_logits, transformer_logits return embgnn_logits, transformer_logits, topk_index @@ -210,7 +207,8 @@ def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): tr_embed = block(top_embed, top_embed) # [# nodes, topk, embed_size] #! for top 50 prediction + out_logits = torch.zeros(gnn_logits.shape).to(gnn_logits.device) # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])]) for idx in range(topk_index.shape[0]): - gnn_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() - return gnn_logits, topk_index \ No newline at end of file + out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() + return out_logits, topk_index \ No newline at end of file From 1eb769cdd4cf75a3ad732209ee385cdcc6516aee Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Wed, 21 Aug 2024 18:44:50 +0000 Subject: [PATCH 14/26] adding reranker transformer --- .../relbench_link_prediction_benchmark.py | 27 ++++++--- hybridgnn/nn/models/rerank_transformer.py | 59 ++++++++++++++----- 2 files changed, 63 insertions(+), 23 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 8bf4ab9..3b5756d 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -32,6 +32,9 @@ from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer, ReRankTransformer from hybridgnn.utils import GloveTextEmbedding +from torch_geometric.utils.map import map_index + + TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] LINK_PREDICTION_METRIC = "link_prediction_map" @@ -146,8 +149,8 @@ "channels": [64], "embedding_dim": [64], "norm": ["layer_norm"], - "dropout": [0.0, 0.1, 0.2], - "rank_topk": [25,50,100, 200] + "dropout": [0.1, 0.2], + "rank_topk": [100] } train_search_space = { "batch_size": [128, 256, 512], @@ -205,6 +208,19 @@ def train( gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(gnn_logits, edge_label_index) + + #! continue here to debug for map_index to only get label for the topk that transformer learns + """ + # batch_size = batch[task.src_entity_table].batch_size + # target = torch.isin( + # batch[task.dst_entity_table].batch + + # batch_size * batch[task.dst_entity_table].n_id, + # src_batch + batch_size * dst_index, + # ).float() + # print (target.shape) + # quit() + # topk_labels = map_index(edge_label_index, topk_idx) + """ loss += sparse_cross_entropy(tr_logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) @@ -258,12 +274,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: gnn_logits, tr_logits, topk_index = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) - gnn_logits, tr_logits, topk_index = gnn_logits.detach(), tr_logits.detach(), topk_index.detach() - for idx in range(topk_index.shape[0]): - gnn_logits[idx][topk_index[idx]] = tr_logits[idx][topk_index[idx]] - - scores = torch.sigmoid(gnn_logits) - #scores = torch.sigmoid(out.detach()) + scores = torch.sigmoid(tr_logits.detach()) else: diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index a62d5bf..e1fb073 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -16,6 +16,7 @@ from torch_scatter import scatter_max from torch_geometric.nn.aggr.utils import MultiheadAttentionBlock from torch_geometric.utils import to_dense_batch +from torch_geometric.utils.map import map_index @@ -93,6 +94,8 @@ def __init__( dropout=dropout, ) for _ in range(1) ]) + # self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim) + self.channels = channels self.reset_parameters() @@ -109,6 +112,7 @@ def reset_parameters(self) -> None: self.lhs_projector.reset_parameters() for block in self.tr_blocks: block.reset_parameters() + # self.tr_lin.reset_parameters() def forward( self, @@ -170,27 +174,52 @@ def forward( embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits - - - """ - detach the variable - """ - all_rhs_embed = rhs_embedding.weight.detach().clone() #only shallow rhs embeds + #! let's do end to end transformer here + all_rhs_embed = rhs_embedding.weight #only shallow rhs embeds assert all_rhs_embed.shape[1] == rhs_gnn_embedding.shape[1], "id GNN embed size should be the same as shallow RHS embed size" - all_rhs_embed[rhs_idgnn_index] = rhs_gnn_embedding.detach().clone() # apply the idGNN embeddings here - - # all_rhs_embed = rhs_embedding.weight #only shallow rhs embeds - # #! this causes error when the channel size and hidden size is different - # assert all_rhs_embed.shape[1] == rhs_gnn_embedding.shape[1], "id GNN embed size should be the same as shallow RHS embed size" + #* rhs_gnn_embedding is significantly smaller than rhs_embed and we can't use inplace operation during backprop + #* -----> this is not global, can't replace like this + copy_tensor = torch.zeros(all_rhs_embed.shape).to(all_rhs_embed.device) + copy_tensor[rhs_idgnn_index] = rhs_gnn_embedding + final_rhs_embed = all_rhs_embed + copy_tensor # all_rhs_embed[rhs_idgnn_index] = rhs_gnn_embedding # apply the idGNN embeddings here + # transformer_logits, topk_index = self.rerank(embgnn_logits.detach().clone(), final_rhs_embed, lhs_idgnn_batch.detach().clone(), lhs_embedding[lhs_idgnn_batch].detach().clone()) + transformer_logits, topk_index = self.rerank(embgnn_logits, final_rhs_embed, lhs_idgnn_batch, lhs_embedding_projected[lhs_idgnn_batch]) - transformer_logits, topk_index = self.rerank(embgnn_logits.detach().clone(), all_rhs_embed, lhs_idgnn_batch.detach().clone(), lhs_embedding[lhs_idgnn_batch].detach().clone()) - # transformer_logits = self.rerank(embgnn_logits, all_rhs_embed, lhs_idgnn_batch, lhs_embedding[lhs_idgnn_batch]) return embgnn_logits, transformer_logits, topk_index + #* adding lhs embedding code not working yet + # def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): + # """ + # reranks the gnn logits based on the provided gnn embeddings. + # rhs_gnn_embedding:[# rhs nodes, embed_dim] + # """ + # topk = self.rank_topk + # _, topk_index = torch.topk(gnn_logits, self.rank_topk, dim=1) + # embed_size = rhs_gnn_embedding.shape[1] + + # # need input batch of size [# nodes, topk, embed_size] + # #! concatenate the lhs embedding with rhs embedding + # top_embed = torch.stack([torch.cat((rhs_gnn_embedding[topk_index[idx]],lhs_embedding[idx].view(1,-1).expand(self.rank_topk,-1)), dim=1) for idx in range(topk_index.shape[0])]) + # tr_embed = top_embed + # for block in self.tr_blocks: + # tr_embed = block(tr_embed, tr_embed) # [# nodes, topk, embed_size] + + # tr_embed = tr_embed.view(-1,embed_size*2) + # tr_embed = self.tr_lin(tr_embed) + # tr_embed = tr_embed.view(-1,self.rank_topk,embed_size) + + + # #! for top k prediction + # out_logits = torch.full(gnn_logits.shape, -float('inf')).to(gnn_logits.device) + # # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])]) + # for idx in range(topk_index.shape[0]): + # out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() + # return out_logits, topk_index + def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): """ @@ -206,8 +235,8 @@ def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): for block in self.tr_blocks: tr_embed = block(top_embed, top_embed) # [# nodes, topk, embed_size] - #! for top 50 prediction - out_logits = torch.zeros(gnn_logits.shape).to(gnn_logits.device) + #! for top k prediction + out_logits = torch.full(gnn_logits.shape, -float('inf')).to(gnn_logits.device) # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])]) for idx in range(topk_index.shape[0]): out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() From 75445cee5618d35974b79b5c32b9e298b747eab3 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Wed, 21 Aug 2024 23:23:38 +0000 Subject: [PATCH 15/26] updating RHS transformer code --- .../relbench_link_prediction_benchmark.py | 27 ++++--- hybridgnn/nn/encoder.py | 42 +++++++---- ...id_rhstransformer.py => RHSTransformer.py} | 75 +++++-------------- hybridgnn/nn/models/__init__.py | 4 +- hybridgnn/nn/models/transformer.py | 39 +--------- 5 files changed, 66 insertions(+), 121 deletions(-) rename hybridgnn/nn/models/{hybrid_rhstransformer.py => RHSTransformer.py} (73%) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index a375492..6ec54be 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -1,3 +1,7 @@ +""" +$ python relbench_link_prediction_benchmark.py --dataset rel-stack --task post-post-related --model rhstransformer --num_trials 10 +""" + import argparse import json import os @@ -30,7 +34,7 @@ from torch_geometric.utils.cross_entropy import sparse_cross_entropy from tqdm import tqdm -from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer, ReRankTransformer +from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer from hybridgnn.utils import GloveTextEmbedding from torch_geometric.utils.map import map_index @@ -105,7 +109,7 @@ int(args.num_neighbors // 2**i) for i in range(args.num_layers) ] -model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, Hybrid_RHSTransformer, ReRankTransformer]] +model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer]] if args.model == "idgnn": model_search_space = { @@ -136,23 +140,27 @@ model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) elif args.model in ["rhstransformer"]: model_search_space = { + "encoder_channels": [64, 128], + "encoder_layers": [2, 4], "channels": [64, 128], "embedding_dim": [64, 128], - "norm": ["layer_norm"], + "norm": ["layer_norm", "batch_norm"], "dropout": [0.1, 0.2], - "pe": [None], + "t_encoding_type": ["fuse", "absolute"], } train_search_space = { - "batch_size": [128, 256, 512], + "batch_size": [128, 256], "base_lr": [0.0005, 0.01], "gamma_rate": [0.9, 1.0], } - model_cls = Hybrid_RHSTransformer + model_cls = RHSTransformer elif args.model in ["rerank_transformer"]: model_search_space = { + "encoder_channels": [64, 128, 256], + "encoder_layers": [2, 4, 8], "channels": [64], "embedding_dim": [64], - "norm": ["layer_norm"], + "norm": ["layer_norm", "batch_norm"], "dropout": [0.1, 0.2], "rank_topk": [100] } @@ -204,7 +212,7 @@ def train( loss = sparse_cross_entropy(logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) elif args.model in ["rhstransformer"]: - logits = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) + logits = model(batch, task.src_entity_table, task.dst_entity_table) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) @@ -271,8 +279,7 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: scores = torch.sigmoid(out) elif args.model in ["rhstransformer"]: out = model(batch, task.src_entity_table, - task.dst_entity_table, - task.dst_entity_col).detach() + task.dst_entity_table).detach() scores = torch.sigmoid(out) elif args.model in ["rerank_transformer"]: gnn_logits, tr_logits, topk_index = model(batch, task.src_entity_table, diff --git a/hybridgnn/nn/encoder.py b/hybridgnn/nn/encoder.py index 9887fc1..2a8c933 100644 --- a/hybridgnn/nn/encoder.py +++ b/hybridgnn/nn/encoder.py @@ -95,9 +95,16 @@ def forward( class HeteroTemporalEncoder(torch.nn.Module): - def __init__(self, node_types: List[NodeType], channels: int) -> None: + def __init__(self, node_types: List[NodeType], channels: int, + encoding_type: Optional[str] = "absolute",) -> None: + r""" + temporal encoder: + encoding_type (str, optional): which type of time encoding to use, options are ["absolute", "learnable", "fuse"] + """ super().__init__() + self.encoding_type = encoding_type # ["absolute", "fuse"] + self.encoder_dict = torch.nn.ModuleDict({ node_type: PositionalEncoding(channels) @@ -109,15 +116,19 @@ def __init__(self, node_types: List[NodeType], channels: int) -> None: for node_type in node_types }) - time_dim = 3 # hour, day, week - self.time_fuser = torch.nn.Linear(time_dim, channels) + if (self.encoding_type == "fuse"): + time_dim = 3 # hour, day, week + self.time_fuser = torch.nn.Linear(time_dim, channels) def reset_parameters(self) -> None: for encoder in self.encoder_dict.values(): encoder.reset_parameters() for lin in self.lin_dict.values(): lin.reset_parameters() - self.time_fuser.reset_parameters() + if (self.encoding_type == "learnable"): + self.day_pe.reset_parameters() + elif (self.encoding_type == "fuse"): + self.time_fuser.reset_parameters() def forward( self, @@ -129,18 +140,17 @@ def forward( for node_type, time in time_dict.items(): rel_time = seed_time[batch_dict[node_type]] - time - - # rel_day = rel_time / SECONDS_IN_A_DAY - # x = self.encoder_dict[node_type](rel_day) - # x = self.encoder_dict[node_type](rel_hour) - rel_hour = (rel_time // SECONDS_IN_A_HOUR).view(-1,1) - rel_day = (rel_time // SECONDS_IN_A_DAY).view(-1,1) - rel_week = (rel_time // SECONDS_IN_A_WEEK).view(-1,1) - time_embed = torch.cat((rel_hour, rel_day, rel_week),dim=1).float() - - #! might need to normalize hour, day, week into the same scale - time_embed = torch.nn.functional.normalize(time_embed, p=2.0, dim=1) - x = self.time_fuser(time_embed) + + if (self.encoding_type == "absolute"): + rel_time = rel_time / SECONDS_IN_A_DAY + x = self.encoder_dict[node_type](rel_time) + elif (self.encoding_type == "fuse"): + rel_hour = (rel_time // SECONDS_IN_A_HOUR).view(-1,1) + rel_day = (rel_time // SECONDS_IN_A_DAY).view(-1,1) + rel_week = (rel_time // SECONDS_IN_A_WEEK).view(-1,1) + time_embed = torch.cat((rel_hour, rel_day, rel_week),dim=1).float() + time_embed = torch.nn.functional.normalize(time_embed, p=2.0, dim=1) #normalize hour, day, week into same scale + x = self.time_fuser(time_embed) x = self.lin_dict[node_type](x) out_dict[node_type] = x diff --git a/hybridgnn/nn/models/hybrid_rhstransformer.py b/hybridgnn/nn/models/RHSTransformer.py similarity index 73% rename from hybridgnn/nn/models/hybrid_rhstransformer.py rename to hybridgnn/nn/models/RHSTransformer.py index e373ef4..585d337 100644 --- a/hybridgnn/nn/models/hybrid_rhstransformer.py +++ b/hybridgnn/nn/models/RHSTransformer.py @@ -1,8 +1,9 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional, Type import torch from torch import Tensor from torch_frame.data.stats import StatType +from torch_frame.nn.models import ResNet from torch_geometric.data import HeteroData from torch_geometric.nn import MLP from torch_geometric.typing import NodeType @@ -13,24 +14,11 @@ HeteroTemporalEncoder, ) from hybridgnn.nn.models import HeteroGraphSAGE -from hybridgnn.nn.models.transformer import RHSTransformer -from torch_scatter import scatter_max - - -class Hybrid_RHSTransformer(torch.nn.Module): - r"""Implementation of RHSTransformer model. - Args: - data (HeteroData): dataset - col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): column stats - num_nodes (int): number of nodes, - num_layers (int): number of mp layers, - channels (int): input dimension, - embedding_dim (int): embedding dimension size, - aggr (str): aggregation type, - norm (norm): normalization type, - dropout (float): dropout rate for the transformer float, - heads (int): number of attention heads, - pe (str): type of positional encoding for the transformer,""" +from hybridgnn.nn.models.transformer import Transformer + + +class RHSTransformer(torch.nn.Module): + r"""Implementation of RHSTransformer model.""" def __init__( self, data: HeteroData, @@ -43,7 +31,9 @@ def __init__( norm: str = 'layer_norm', dropout: float = 0.2, heads: int = 1, - pe: str = "abs", + t_encoding_type: str = "absolute", + torch_frame_model_cls: Type[torch.nn.Module] = ResNet, + torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() @@ -55,6 +45,8 @@ def __init__( }, node_to_col_stats=col_stats_dict, stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, + torch_frame_model_cls=torch_frame_model_cls, + torch_frame_model_kwargs=torch_frame_model_kwargs, ) self.temporal_encoder = HeteroTemporalEncoder( node_types=[ @@ -62,6 +54,7 @@ def __init__( if "time" in data[node_type] ], channels=channels, + encoding_type=t_encoding_type, ) self.gnn = HeteroGraphSAGE( node_types=data.node_types, @@ -81,14 +74,12 @@ def __init__( self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim) self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1) self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1) + self.channels = channels - self.rhs_transformer = RHSTransformer(in_channels=channels, + self.rhs_transformer = Transformer(in_channels=channels, out_channels=channels, hidden_channels=channels, - heads=heads, dropout=dropout, - position_encoding=pe) - - self.channels = channels + heads=heads, dropout=dropout) self.reset_parameters() @@ -109,7 +100,6 @@ def forward( batch: HeteroData, entity_table: NodeType, dst_table: NodeType, - dst_entity_col: NodeType, ) -> Tensor: seed_time = batch[entity_table].seed_time x_dict = self.encoder(batch.tf_dict) @@ -135,15 +125,8 @@ def forward( rhs_gnn_embedding = x_dict[dst_table] # num_sampled_rhs, channel rhs_idgnn_index = batch.n_id_dict[dst_table] # num_sampled_rhs lhs_idgnn_batch = batch.batch_dict[dst_table] # batch_size - - #! need custom code to work for specific datasets - # rhs_time = self.get_rhs_time_dict(batch.time_dict, batch.edge_index_dict, batch[entity_table].seed_time, batch, dst_entity_col, dst_table) - - # adding rhs transformer - rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding, - lhs_idgnn_batch, batch_size=batch_size) - rhs_embedding = self.rhs_embedding # num_rhs_nodes, channel + embgnn_logits = lhs_embedding_projected @ rhs_embedding.weight.t( ) # batch_size, num_rhs_nodes @@ -152,6 +135,9 @@ def forward( lhs_embedding_projected).flatten() embgnn_logits += embgnn_offset_logits.view(-1, 1) + #* transformer forward pass + rhs_gnn_embedding = self.rhs_transformer(rhs_gnn_embedding, + lhs_idgnn_batch, batch_size=batch_size) # Calculate idgnn logits idgnn_logits = self.head( rhs_gnn_embedding).flatten() # num_sampled_rhs @@ -170,24 +156,3 @@ def forward( embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits return embgnn_logits - - def get_rhs_time_dict( - self, - time_dict, - edge_index_dict, - seed_time, - batch_dict, - dst_entity_col, - dst_entity_table, - ): - #* what to put when transaction table is merged - edge_index = edge_index_dict['sponsors','f2p_sponsor_id', - 'sponsors_studies'] - rhs_time, _ = scatter_max( - time_dict['sponsors'][edge_index[0]], - edge_index[1]) - SECONDS_IN_A_DAY = 60 * 60 * 24 - NANOSECONDS_IN_A_DAY = 60 * 60 * 24 * 1000000000 - rhs_rel_time = seed_time[batch_dict[dst_entity_col]] - rhs_time - rhs_rel_time = rhs_rel_time / NANOSECONDS_IN_A_DAY - return rhs_rel_time diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py index 5d9f3cf..7b6c7e8 100644 --- a/hybridgnn/nn/models/__init__.py +++ b/hybridgnn/nn/models/__init__.py @@ -2,10 +2,10 @@ from .idgnn import IDGNN from .hybridgnn import HybridGNN from .shallowrhsgnn import ShallowRHSGNN -from .hybrid_rhstransformer import Hybrid_RHSTransformer +from .RHSTransformer import RHSTransformer from .rerank_transformer import ReRankTransformer __all__ = classes = [ 'HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN', - 'Hybrid_RHSTransformer', 'ReRankTransformer' + 'RHSTransformer', 'ReRankTransformer' ] diff --git a/hybridgnn/nn/models/transformer.py b/hybridgnn/nn/models/transformer.py index a9cf2ac..f820802 100644 --- a/hybridgnn/nn/models/transformer.py +++ b/hybridgnn/nn/models/transformer.py @@ -10,7 +10,7 @@ from torch_geometric.nn.encoding import PositionalEncoding -class RHSTransformer(torch.nn.Module): +class Transformer(torch.nn.Module): r"""A module to attend to rhs embeddings with a transformer. Args: in_channels (int): The number of input channels of the RHS embedding. @@ -29,7 +29,6 @@ def __init__( heads: int = 1, num_transformer_blocks: int = 1, dropout: float = 0.0, - position_encoding: str = "abs", ) -> None: super().__init__() @@ -38,15 +37,6 @@ def __init__( self.hidden_channels = hidden_channels self.lin = torch.nn.Linear(in_channels, hidden_channels) self.fc = torch.nn.Linear(hidden_channels, out_channels) - self.pe_type = position_encoding - self.pe = None - if (position_encoding == "abs"): - self.pe = PositionalEncoding(hidden_channels) - elif (position_encoding is None): - self.pe = None - else: - raise NotImplementedError - self.blocks = torch.nn.ModuleList([ MultiheadAttentionBlock( channels=hidden_channels, @@ -68,10 +58,6 @@ def forward(self, rhs_embed: Tensor, index: Tensor, batch_size) -> Tensor: """ rhs_embed = self.lin(rhs_embed) - if (self.pe_type == "abs"): - rhs_embed = rhs_embed + self.pe( - torch.arange(rhs_embed.size(0), device=rhs_embed.device)) - # #! if we sort the index, we need to sort the rhs_embed sorted_index, sorted_idx = torch.sort(index, stable=True) index = index[sorted_idx] @@ -91,26 +77,3 @@ def inverse_permutation(self,perm): inv[perm] = torch.arange(perm.size(0), device=perm.device) return inv - -class RotaryPositionalEmbeddings(torch.nn.Module): - def __init__(self, channels, base=10000): - super().__init__() - self.channels = channels - self.base = base - self.inv_freq = 1. / (base**(torch.arange(0, channels, 2).float() / - channels)) - - def forward(self, x, pos=None): - seq_len = x.shape[1] - if (pos is None): - pos = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) - freqs = torch.einsum('i,j->ij', pos, self.inv_freq) - emb = torch.cat((freqs, freqs), dim=-1) - - cos = emb.cos().to(x.device) - sin = emb.sin().to(x.device) - - x1, x2 = x[..., ::2], x[..., 1::2] - rotated = torch.stack([-x2, x1], dim=-1).reshape(x.shape).to(x.device) - - return x * cos + rotated * sin From 7380a1709f0fda2aafd4d039a9130b96e309233a Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Fri, 23 Aug 2024 19:30:49 +0000 Subject: [PATCH 16/26] updating rerank_transformer --- .../relbench_link_prediction_benchmark.py | 42 ++++---- hybridgnn/nn/models/rerank_transformer.py | 98 ++++++++++++------- 2 files changed, 83 insertions(+), 57 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 6ec54be..30042eb 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -36,6 +36,7 @@ from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer from hybridgnn.utils import GloveTextEmbedding +from torch_geometric.utils import index_to_mask from torch_geometric.utils.map import map_index @@ -71,7 +72,8 @@ parser.add_argument("--result_path", type=str, default="result") args = parser.parse_args() -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = "cpu" if torch.cuda.is_available(): torch.set_num_threads(1) seed_everything(args.seed) @@ -220,22 +222,21 @@ def train( gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(gnn_logits, edge_label_index) - - #! continue here to debug for map_index to only get label for the topk that transformer learns - """ - # batch_size = batch[task.src_entity_table].batch_size - # target = torch.isin( - # batch[task.dst_entity_table].batch + - # batch_size * batch[task.dst_entity_table].n_id, - # src_batch + batch_size * dst_index, - # ).float() - # print (target.shape) - # quit() - # topk_labels = map_index(edge_label_index, topk_idx) - """ - loss += sparse_cross_entropy(tr_logits, edge_label_index) + num_rhs_nodes = gnn_logits.shape[1] + + #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction + batch_size = topk_idx.shape[0] + topk = topk_idx.shape[1] + idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) + topk_idx = topk_idx + idx_position + + correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() + loss += F.binary_cross_entropy_with_logits(tr_logits, correct_label) + # true_label_index, mask = map_index(topk_idx, edge_label_index) + # correct_label = torch.zeros(tr_logits.shape).to(tr_logits.device) + # correct_label[mask] = True + # loss += sparse_cross_entropy(tr_logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) - loss.backward() optimizer.step() @@ -282,11 +283,14 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: task.dst_entity_table).detach() scores = torch.sigmoid(out) elif args.model in ["rerank_transformer"]: - gnn_logits, tr_logits, topk_index = model(batch, task.src_entity_table, + gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) - scores = torch.sigmoid(tr_logits.detach()) - + #! need to change the shape of tr_logits + scores = torch.zeros(batch_size, task.num_dst_nodes, + device=tr_logits.device) + scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) + # scores[topk_index] = torch.sigmoid(tr_logits.detach().flatten()) else: raise ValueError(f"Unsupported model type: {args.model}.") diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index e1fb073..28ef133 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -1,9 +1,10 @@ -from typing import Any, Dict +from typing import Any, Dict, Optional, Type import torch from torch import Tensor from torch_frame.data.stats import StatType from torch_geometric.data import HeteroData +from torch_frame.nn.models import ResNet from torch_geometric.nn import MLP from torch_geometric.typing import NodeType @@ -47,6 +48,8 @@ def __init__( dropout: float = 0.2, heads: int = 1, rank_topk: int = 100, + torch_frame_model_cls: Type[torch.nn.Module] = ResNet, + torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() @@ -58,6 +61,8 @@ def __init__( }, node_to_col_stats=col_stats_dict, stype_encoder_cls_kwargs=DEFAULT_STYPE_ENCODER_DICT, + torch_frame_model_cls=torch_frame_model_cls, + torch_frame_model_kwargs=torch_frame_model_kwargs, ) self.temporal_encoder = HeteroTemporalEncoder( node_types=[ @@ -88,13 +93,13 @@ def __init__( self.rank_topk = rank_topk self.tr_blocks = torch.nn.ModuleList([ MultiheadAttentionBlock( - channels=embedding_dim, + channels=embedding_dim*2, heads=heads, layer_norm=True, dropout=dropout, ) for _ in range(1) ]) - # self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim) + self.tr_lin = torch.nn.Linear(embedding_dim*2, 1) self.channels = channels @@ -112,7 +117,7 @@ def reset_parameters(self) -> None: self.lhs_projector.reset_parameters() for block in self.tr_blocks: block.reset_parameters() - # self.tr_lin.reset_parameters() + self.tr_lin.reset_parameters() def forward( self, @@ -159,6 +164,7 @@ def forward( # Calculate idgnn logits idgnn_logits = self.head( rhs_gnn_embedding).flatten() # num_sampled_rhs + # Because we are only doing 2 hop, we are not really sampling info from # lhs therefore, we need to incorporate this information using # lhs_embedding[lhs_idgnn_batch] * rhs_gnn_embedding @@ -174,22 +180,60 @@ def forward( embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits - #! let's do end to end transformer here - all_rhs_embed = rhs_embedding.weight #only shallow rhs embeds - assert all_rhs_embed.shape[1] == rhs_gnn_embedding.shape[1], "id GNN embed size should be the same as shallow RHS embed size" + shallow_rhs_embed = rhs_embedding.weight + transformer_logits, topk_index = self.rerank(embgnn_logits, shallow_rhs_embed, rhs_gnn_embedding, rhs_idgnn_index, idgnn_logits, lhs_idgnn_batch,lhs_embedding_projected[lhs_idgnn_batch]) + return embgnn_logits, transformer_logits, topk_index - #* rhs_gnn_embedding is significantly smaller than rhs_embed and we can't use inplace operation during backprop - #* -----> this is not global, can't replace like this - copy_tensor = torch.zeros(all_rhs_embed.shape).to(all_rhs_embed.device) - copy_tensor[rhs_idgnn_index] = rhs_gnn_embedding - final_rhs_embed = all_rhs_embed + copy_tensor - # all_rhs_embed[rhs_idgnn_index] = rhs_gnn_embedding # apply the idGNN embeddings here - # transformer_logits, topk_index = self.rerank(embgnn_logits.detach().clone(), final_rhs_embed, lhs_idgnn_batch.detach().clone(), lhs_embedding[lhs_idgnn_batch].detach().clone()) - transformer_logits, topk_index = self.rerank(embgnn_logits, final_rhs_embed, lhs_idgnn_batch, lhs_embedding_projected[lhs_idgnn_batch]) + def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index, idgnn_logits, lhs_idgnn_batch, lhs_embedding): + """ + reranks the gnn logits based on the provided gnn embeddings. + shallow_rhs_embed:[# rhs nodes, embed_dim] + + """ + embed_size = rhs_idgnn_embed.shape[1] + batch_size = gnn_logits.shape[0] + num_rhs_nodes = shallow_rhs_embed.shape[0] + + filtered_logits, topk_indices = torch.topk(gnn_logits, self.rank_topk, dim=1) + # [batch_size, topk, embed_size] + seq = shallow_rhs_embed[topk_indices.flatten()].view(batch_size * self.rank_topk, embed_size) + rhs_idgnn_index = lhs_idgnn_batch * num_rhs_nodes + rhs_idgnn_index + + query_rhs_idgnn_index, mask = map_index(topk_indices.view(-1), rhs_idgnn_index) + id_gnn_seq = torch.zeros(batch_size * self.rank_topk, embed_size) + id_gnn_seq[mask] = rhs_idgnn_embed[query_rhs_idgnn_index] + + logit_mask = torch.zeros(batch_size * self.rank_topk, embed_size, dtype=bool) + logit_mask[mask] = True + seq = torch.where(logit_mask, id_gnn_seq.view(-1,embed_size), seq.view(-1,embed_size)) + + unique_lhs_idx = torch.unique(lhs_idgnn_batch) + lhs_uniq_embed = lhs_embedding[unique_lhs_idx] + + seq = seq.clone() + seq = seq.view(batch_size,self.rank_topk,-1) + + lhs_uniq_embed = lhs_uniq_embed.view(-1,1,embed_size) + lhs_uniq_embed = lhs_uniq_embed.expand(-1,seq.shape[1],-1) + seq = torch.cat((seq,lhs_uniq_embed), dim=-1) + + for block in self.tr_blocks: + seq = block(seq, seq) # [# nodes, topk, embed_size] + + #! just get the logit directly from transformer + seq = seq.view(-1,embed_size*2) + seq = self.tr_lin(seq) + topk_logits = seq.view(batch_size,self.rank_topk) + + _, topk_indices = torch.topk(gnn_logits, self.rank_topk, dim=1) + return topk_logits, topk_indices + + + + - return embgnn_logits, transformer_logits, topk_index #* adding lhs embedding code not working yet # def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): @@ -219,25 +263,3 @@ def forward( # for idx in range(topk_index.shape[0]): # out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() # return out_logits, topk_index - - - def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): - """ - reranks the gnn logits based on the provided gnn embeddings. - rhs_gnn_embedding:[# rhs nodes, embed_dim] - """ - topk = self.rank_topk - _, topk_index = torch.topk(gnn_logits, self.rank_topk, dim=1) - embed_size = rhs_gnn_embedding.shape[1] - - # need input batch of size [# nodes, topk, embed_size] - top_embed = torch.stack([rhs_gnn_embedding[topk_index[idx]] for idx in range(topk_index.shape[0])]) - for block in self.tr_blocks: - tr_embed = block(top_embed, top_embed) # [# nodes, topk, embed_size] - - #! for top k prediction - out_logits = torch.full(gnn_logits.shape, -float('inf')).to(gnn_logits.device) - # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])]) - for idx in range(topk_index.shape[0]): - out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() - return out_logits, topk_index \ No newline at end of file From 7622b94f3a8d130cfc2cd2e3948cd8a39966a746 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Fri, 23 Aug 2024 21:22:08 +0000 Subject: [PATCH 17/26] push current version --- benchmark/relbench_link_prediction_benchmark.py | 12 ++++++++++-- hybridgnn/nn/models/rerank_transformer.py | 14 +++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 30042eb..e1545c8 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -230,13 +230,18 @@ def train( idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) topk_idx = topk_idx + idx_position + """ + debug if this is correct + """ correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() loss += F.binary_cross_entropy_with_logits(tr_logits, correct_label) + numel = len(batch[task.dst_entity_table].batch) + + # true_label_index, mask = map_index(topk_idx, edge_label_index) # correct_label = torch.zeros(tr_logits.shape).to(tr_logits.device) # correct_label[mask] = True # loss += sparse_cross_entropy(tr_logits, edge_label_index) - numel = len(batch[task.dst_entity_table].batch) loss.backward() optimizer.step() @@ -289,7 +294,10 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: #! need to change the shape of tr_logits scores = torch.zeros(batch_size, task.num_dst_nodes, device=tr_logits.device) - scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) + tr_logits = tr_logits.detach() + for i in range(scores.shape[0]): + scores[i][topk_idx[i]] = torch.sigmoid(tr_logits[i]) + # scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) # scores[topk_index] = torch.sigmoid(tr_logits.detach().flatten()) else: diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index 28ef133..d7904b0 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -99,7 +99,7 @@ def __init__( dropout=dropout, ) for _ in range(1) ]) - self.tr_lin = torch.nn.Linear(embedding_dim*2, 1) + self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim) self.channels = channels @@ -196,6 +196,7 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index num_rhs_nodes = shallow_rhs_embed.shape[0] filtered_logits, topk_indices = torch.topk(gnn_logits, self.rank_topk, dim=1) + out_indices = topk_indices.clone() # [batch_size, topk, embed_size] seq = shallow_rhs_embed[topk_indices.flatten()].view(batch_size * self.rank_topk, embed_size) rhs_idgnn_index = lhs_idgnn_batch * num_rhs_nodes + rhs_idgnn_index @@ -223,12 +224,15 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index #! just get the logit directly from transformer seq = seq.view(-1,embed_size*2) - seq = self.tr_lin(seq) - topk_logits = seq.view(batch_size,self.rank_topk) + seq = self.tr_lin(seq) # [batch_size, embed_size] + seq = seq.view(batch_size * self.rank_topk, embed_size) + lhs_uniq_embed = lhs_uniq_embed.reshape(batch_size * self.rank_topk, embed_size) - _, topk_indices = torch.topk(gnn_logits, self.rank_topk, dim=1) - return topk_logits, topk_indices + tr_logits = (lhs_uniq_embed.view(-1, embed_size) * seq.view(-1, embed_size)).sum( + dim=-1).flatten() + tr_logits = tr_logits.view(batch_size,self.rank_topk) + return tr_logits, out_indices From dc73f7c7f7739c01f4d6b9ce3ad183372f052749 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Fri, 23 Aug 2024 22:38:45 +0000 Subject: [PATCH 18/26] converging to the follow implementation Please enter the commit message for your changes. Lines starting --- .../relbench_link_prediction_benchmark.py | 60 ++++++++++--------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index e1545c8..3115de5 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -164,7 +164,7 @@ "embedding_dim": [64], "norm": ["layer_norm", "batch_norm"], "dropout": [0.1, 0.2], - "rank_topk": [100] + "rank_topk": [500], } train_search_space = { "batch_size": [128, 256, 512], @@ -178,6 +178,7 @@ def train( optimizer: torch.optim.Optimizer, loader: NeighborLoader, train_sparse_tensor: SparseTensor, + epoch:int, ) -> float: model.train() @@ -222,26 +223,23 @@ def train( gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(gnn_logits, edge_label_index) - num_rhs_nodes = gnn_logits.shape[1] - - #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction - batch_size = topk_idx.shape[0] - topk = topk_idx.shape[1] - idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) - topk_idx = topk_idx + idx_position - - """ - debug if this is correct - """ - correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() - loss += F.binary_cross_entropy_with_logits(tr_logits, correct_label) numel = len(batch[task.dst_entity_table].batch) - - # true_label_index, mask = map_index(topk_idx, edge_label_index) - # correct_label = torch.zeros(tr_logits.shape).to(tr_logits.device) - # correct_label[mask] = True - # loss += sparse_cross_entropy(tr_logits, edge_label_index) + if (epoch > 0): + # num_rhs_nodes = gnn_logits.shape[1] + # #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction, likely incorrect + # batch_size = topk_idx.shape[0] + # topk = topk_idx.shape[1] + # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) + # topk_idx = topk_idx + idx_position + # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() + + #* approach with map_index + label_index, mask = map_index(topk_idx.view(-1), dst_index) + true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) + true_label[mask.view(true_label.shape)] = 1.0 + loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) + loss.backward() optimizer.step() @@ -278,32 +276,40 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: device=out.device) scores[batch[task.dst_entity_table].batch, batch[task.dst_entity_table].n_id] = torch.sigmoid(out) + _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) elif args.model in ["hybridgnn", "shallowrhsgnn"]: # Get ground-truth out = model(batch, task.src_entity_table, task.dst_entity_table).detach() scores = torch.sigmoid(out) + _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) elif args.model in ["rhstransformer"]: out = model(batch, task.src_entity_table, task.dst_entity_table).detach() scores = torch.sigmoid(out) + _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) elif args.model in ["rerank_transformer"]: gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) - #! need to change the shape of tr_logits - scores = torch.zeros(batch_size, task.num_dst_nodes, - device=tr_logits.device) - tr_logits = tr_logits.detach() - for i in range(scores.shape[0]): - scores[i][topk_idx[i]] = torch.sigmoid(tr_logits[i]) + + _, pred_idx = torch.topk(tr_logits.detach(), k=task.eval_k, dim=1) + pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] + + #! to remove + # scores = torch.zeros(batch_size, task.num_dst_nodes, + # device=tr_logits.device) + # tr_logits = tr_logits.detach() + # scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) + + # for i in range(scores.shape[0]): + # scores[i][topk_idx[i]] = torch.sigmoid(tr_logits[i]) # scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) # scores[topk_index] = torch.sigmoid(tr_logits.detach().flatten()) else: raise ValueError(f"Unsupported model type: {args.model}.") - _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) pred_list.append(pred_mini) pred = torch.cat(pred_list, dim=0).cpu().numpy() @@ -371,7 +377,7 @@ def train_and_eval_with_cfg( train_sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device) train_loss = train(model, optimizer, loader_dict["train"], - train_sparse_tensor) + train_sparse_tensor, epoch) optimizer.zero_grad() val_metric = test(model, loader_dict["val"], "val") From 38f9cf481375ac7a8ed20ec46cc3bd2e315efeef Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Fri, 23 Aug 2024 22:41:02 +0000 Subject: [PATCH 19/26] training both from epoch 0, potentially try better ways --- .../relbench_link_prediction_benchmark.py | 27 +++++++++---------- hybridgnn/nn/models/rerank_transformer.py | 1 - 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 3115de5..13ddae0 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -225,20 +225,19 @@ def train( loss = sparse_cross_entropy(gnn_logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) - if (epoch > 0): - # num_rhs_nodes = gnn_logits.shape[1] - # #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction, likely incorrect - # batch_size = topk_idx.shape[0] - # topk = topk_idx.shape[1] - # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) - # topk_idx = topk_idx + idx_position - # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() - - #* approach with map_index - label_index, mask = map_index(topk_idx.view(-1), dst_index) - true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) - true_label[mask.view(true_label.shape)] = 1.0 - loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) + # num_rhs_nodes = gnn_logits.shape[1] + # #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction, likely incorrect + # batch_size = topk_idx.shape[0] + # topk = topk_idx.shape[1] + # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) + # topk_idx = topk_idx + idx_position + # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() + + #* approach with map_index + label_index, mask = map_index(topk_idx.view(-1), dst_index) + true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) + true_label[mask.view(true_label.shape)] = 1.0 + loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) loss.backward() diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index d7904b0..874e209 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -189,7 +189,6 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index """ reranks the gnn logits based on the provided gnn embeddings. shallow_rhs_embed:[# rhs nodes, embed_dim] - """ embed_size = rhs_idgnn_embed.shape[1] batch_size = gnn_logits.shape[0] From 8e846dd2bf9dfcf6202b4f084546b006b0af54da Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Sun, 25 Aug 2024 18:28:03 +0000 Subject: [PATCH 20/26] for transformer, not training with nodes whose prediction is not within topk --- benchmark/relbench_link_prediction_benchmark.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 13ddae0..c7e3a6b 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -164,7 +164,7 @@ "embedding_dim": [64], "norm": ["layer_norm", "batch_norm"], "dropout": [0.1, 0.2], - "rank_topk": [500], + "rank_topk": [200], } train_search_space = { "batch_size": [128, 256, 512], @@ -237,6 +237,13 @@ def train( label_index, mask = map_index(topk_idx.view(-1), dst_index) true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) true_label[mask.view(true_label.shape)] = 1.0 + + #* empty label rows + nonzero_mask = torch.any(true_label, dim=1) + tr_logits = tr_logits[nonzero_mask] + true_label = true_label[nonzero_mask] + + #* the loss of the transformer should be scaled down? ((topk_idx.shape[1] / gnn_logits.shape[1])) loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) loss.backward() From 41a38cf10f1ca77a783419512af91bf14c67b21c Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Mon, 26 Aug 2024 18:01:05 +0000 Subject: [PATCH 21/26] commiting before clean up --- benchmark/relbench_link_prediction_benchmark.py | 14 +++++++------- hybridgnn/nn/models/rerank_transformer.py | 3 +-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index c7e3a6b..663bbfd 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -164,7 +164,7 @@ "embedding_dim": [64], "norm": ["layer_norm", "batch_norm"], "dropout": [0.1, 0.2], - "rank_topk": [200], + "rank_topk": [100], } train_search_space = { "batch_size": [128, 256, 512], @@ -232,20 +232,20 @@ def train( # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) # topk_idx = topk_idx + idx_position # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() - + #* approach with map_index label_index, mask = map_index(topk_idx.view(-1), dst_index) true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) true_label[mask.view(true_label.shape)] = 1.0 - #* empty label rows - nonzero_mask = torch.any(true_label, dim=1) - tr_logits = tr_logits[nonzero_mask] - true_label = true_label[nonzero_mask] + # #* empty label rows + # nonzero_mask = torch.any(true_label, dim=1) + # tr_logits = tr_logits[nonzero_mask] + # true_label = true_label[nonzero_mask] #* the loss of the transformer should be scaled down? ((topk_idx.shape[1] / gnn_logits.shape[1])) loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) - + loss.backward() optimizer.step() diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index 874e209..1915a10 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -208,8 +208,7 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index logit_mask[mask] = True seq = torch.where(logit_mask, id_gnn_seq.view(-1,embed_size), seq.view(-1,embed_size)) - unique_lhs_idx = torch.unique(lhs_idgnn_batch) - lhs_uniq_embed = lhs_embedding[unique_lhs_idx] + lhs_uniq_embed = lhs_embedding[:batch_size] seq = seq.clone() seq = seq.view(batch_size,self.rank_topk,-1) From a181f453c12e1a854f77d5289e582b3b1220d4ca Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Tue, 27 Aug 2024 19:32:03 +0000 Subject: [PATCH 22/26] semi working version of rerank transformer, still needs work for improvement --- .../relbench_link_prediction_benchmark.py | 34 ++++++++------ hybridgnn/nn/models/rerank_transformer.py | 46 ++++--------------- 2 files changed, 28 insertions(+), 52 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 663bbfd..62df796 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -163,8 +163,8 @@ "channels": [64], "embedding_dim": [64], "norm": ["layer_norm", "batch_norm"], - "dropout": [0.1, 0.2], - "rank_topk": [100], + "dropout": [0.0], + "rank_topk": [200], } train_search_space = { "batch_size": [128, 256, 512], @@ -224,14 +224,6 @@ def train( edge_label_index = torch.stack([src_batch, dst_index], dim=0) loss = sparse_cross_entropy(gnn_logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) - - # num_rhs_nodes = gnn_logits.shape[1] - # #* tr_logits: [batch_size, topk], we need to get the edges that exist in the topk prediction, likely incorrect - # batch_size = topk_idx.shape[0] - # topk = topk_idx.shape[1] - # idx_position = (torch.arange(batch_size) * num_rhs_nodes).view(-1,1).to(tr_logits.device) - # topk_idx = topk_idx + idx_position - # correct_label = torch.isin(topk_idx,src_batch * num_rhs_nodes + dst_index).float() #* approach with map_index label_index, mask = map_index(topk_idx.view(-1), dst_index) @@ -239,9 +231,9 @@ def train( true_label[mask.view(true_label.shape)] = 1.0 # #* empty label rows - # nonzero_mask = torch.any(true_label, dim=1) - # tr_logits = tr_logits[nonzero_mask] - # true_label = true_label[nonzero_mask] + nonzero_mask = torch.any(true_label, dim=1) + tr_logits = tr_logits[nonzero_mask] + true_label = true_label[nonzero_mask] #* the loss of the transformer should be scaled down? ((topk_idx.shape[1] / gnn_logits.shape[1])) loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) @@ -298,10 +290,22 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) - - _, pred_idx = torch.topk(tr_logits.detach(), k=task.eval_k, dim=1) + + _, pred_idx = torch.topk(torch.sigmoid(tr_logits).detach(), k=task.eval_k, dim=1) pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] + # _, pred_mini = torch.topk(torch.sigmoid(gnn_logits.detach()), k=task.eval_k, dim=1) + # gnn_out = pred_mini[0] + # sort_out, _ = torch.sort(gnn_out) + # gnn_out = sort_out + + # tr_out = pred_mini[0] + # sort_out, _ = torch.sort(tr_out) + # tr_out = sort_out + + # assert torch.equal(gnn_out, tr_out) + + #! to remove # scores = torch.zeros(batch_size, task.num_dst_nodes, # device=tr_logits.device) diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index 1915a10..201d7ba 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -99,7 +99,8 @@ def __init__( dropout=dropout, ) for _ in range(1) ]) - self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim) + # self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim) + self.tr_lin = torch.nn.Linear(embedding_dim*2,1) self.channels = channels @@ -198,7 +199,7 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index out_indices = topk_indices.clone() # [batch_size, topk, embed_size] seq = shallow_rhs_embed[topk_indices.flatten()].view(batch_size * self.rank_topk, embed_size) - rhs_idgnn_index = lhs_idgnn_batch * num_rhs_nodes + rhs_idgnn_index + # rhs_idgnn_index = lhs_idgnn_batch * num_rhs_nodes + rhs_idgnn_index query_rhs_idgnn_index, mask = map_index(topk_indices.view(-1), rhs_idgnn_index) id_gnn_seq = torch.zeros(batch_size * self.rank_topk, embed_size) @@ -222,12 +223,13 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index #! just get the logit directly from transformer seq = seq.view(-1,embed_size*2) - seq = self.tr_lin(seq) # [batch_size, embed_size] - seq = seq.view(batch_size * self.rank_topk, embed_size) - lhs_uniq_embed = lhs_uniq_embed.reshape(batch_size * self.rank_topk, embed_size) + tr_logits = self.tr_lin(seq) # [batch_size, embed_size] + # seq = seq.view(batch_size * self.rank_topk, embed_size) + # lhs_uniq_embed = lhs_uniq_embed.reshape(batch_size * self.rank_topk, embed_size) + + # tr_logits = (lhs_uniq_embed.view(-1, embed_size) * seq.view(-1, embed_size)).sum( + # dim=-1).flatten() - tr_logits = (lhs_uniq_embed.view(-1, embed_size) * seq.view(-1, embed_size)).sum( - dim=-1).flatten() tr_logits = tr_logits.view(batch_size,self.rank_topk) return tr_logits, out_indices @@ -235,33 +237,3 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index - - - #* adding lhs embedding code not working yet - # def rerank(self, gnn_logits, rhs_gnn_embedding, index, lhs_embedding): - # """ - # reranks the gnn logits based on the provided gnn embeddings. - # rhs_gnn_embedding:[# rhs nodes, embed_dim] - # """ - # topk = self.rank_topk - # _, topk_index = torch.topk(gnn_logits, self.rank_topk, dim=1) - # embed_size = rhs_gnn_embedding.shape[1] - - # # need input batch of size [# nodes, topk, embed_size] - # #! concatenate the lhs embedding with rhs embedding - # top_embed = torch.stack([torch.cat((rhs_gnn_embedding[topk_index[idx]],lhs_embedding[idx].view(1,-1).expand(self.rank_topk,-1)), dim=1) for idx in range(topk_index.shape[0])]) - # tr_embed = top_embed - # for block in self.tr_blocks: - # tr_embed = block(tr_embed, tr_embed) # [# nodes, topk, embed_size] - - # tr_embed = tr_embed.view(-1,embed_size*2) - # tr_embed = self.tr_lin(tr_embed) - # tr_embed = tr_embed.view(-1,self.rank_topk,embed_size) - - - # #! for top k prediction - # out_logits = torch.full(gnn_logits.shape, -float('inf')).to(gnn_logits.device) - # # tr_logits = torch.stack([(lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() for idx in range(topk_index.shape[0])]) - # for idx in range(topk_index.shape[0]): - # out_logits[idx][topk_index[idx]] = (lhs_embedding[idx] * tr_embed[idx]).sum(dim=-1).flatten() - # return out_logits, topk_index From b86641fe12eb3b2adfcf1c4f9ad368f1d5ce3f73 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Wed, 28 Aug 2024 21:59:52 +0000 Subject: [PATCH 23/26] somewhat stable version of the idea --- .../relbench_link_prediction_benchmark.py | 88 +++++++------------ hybridgnn/nn/models/rerank_transformer.py | 40 +++++---- 2 files changed, 55 insertions(+), 73 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 62df796..31a9a78 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -39,7 +39,7 @@ from torch_geometric.utils import index_to_mask from torch_geometric.utils.map import map_index - +PRETRAIN_EPOCH = 1 TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] LINK_PREDICTION_METRIC = "link_prediction_map" @@ -72,8 +72,7 @@ parser.add_argument("--result_path", type=str, default="result") args = parser.parse_args() -# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -device = "cpu" +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): torch.set_num_threads(1) seed_everything(args.seed) @@ -153,23 +152,24 @@ train_search_space = { "batch_size": [128, 256], "base_lr": [0.0005, 0.01], - "gamma_rate": [0.9, 1.0], + "gamma_rate": [0.8, 1.0], } model_cls = RHSTransformer elif args.model in ["rerank_transformer"]: model_search_space = { "encoder_channels": [64, 128, 256], "encoder_layers": [2, 4, 8], - "channels": [64], - "embedding_dim": [64], + "channels": [64,128,256], "norm": ["layer_norm", "batch_norm"], - "dropout": [0.0], - "rank_topk": [200], + "dropout": [0.0,0.1,0.2], + "t_encoding_type": ["fuse", "absolute"], + "rank_topk": [100,150,200], + "num_tr_layers": [1,2,3], } train_search_space = { - "batch_size": [128, 256, 512], + "batch_size": [256, 512], "base_lr": [0.0005, 0.01], - "gamma_rate": [0.9, 1.0], + "gamma_rate": [0.8,1.0], } model_cls = ReRankTransformer @@ -220,23 +220,22 @@ def train( loss = sparse_cross_entropy(logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) elif args.model in ["rerank_transformer"]: - gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) - edge_label_index = torch.stack([src_batch, dst_index], dim=0) - loss = sparse_cross_entropy(gnn_logits, edge_label_index) numel = len(batch[task.dst_entity_table].batch) - - #* approach with map_index - label_index, mask = map_index(topk_idx.view(-1), dst_index) - true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) - true_label[mask.view(true_label.shape)] = 1.0 - - # #* empty label rows - nonzero_mask = torch.any(true_label, dim=1) - tr_logits = tr_logits[nonzero_mask] - true_label = true_label[nonzero_mask] - - #* the loss of the transformer should be scaled down? ((topk_idx.shape[1] / gnn_logits.shape[1])) - loss += F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) + gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) + if (epoch <= PRETRAIN_EPOCH): + edge_label_index = torch.stack([src_batch, dst_index], dim=0) + loss = sparse_cross_entropy(gnn_logits, edge_label_index) + else: + #* approach with map_index + label_index, mask = map_index(topk_idx.view(-1), dst_index) + true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) + true_label[mask.view(true_label.shape)] = 1.0 + + # # #* empty label rows + # nonzero_mask = torch.any(true_label, dim=1) + # tr_logits = tr_logits[nonzero_mask] + # true_label = true_label[nonzero_mask] + loss = F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) loss.backward() @@ -259,7 +258,7 @@ def train( @torch.no_grad() -def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: +def test(model: torch.nn.Module, loader: NeighborLoader, stage: str, epoch:int,) -> float: model.eval() pred_list: List[Tensor] = [] @@ -290,32 +289,13 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float: gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) - - _, pred_idx = torch.topk(torch.sigmoid(tr_logits).detach(), k=task.eval_k, dim=1) - pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] - - # _, pred_mini = torch.topk(torch.sigmoid(gnn_logits.detach()), k=task.eval_k, dim=1) - # gnn_out = pred_mini[0] - # sort_out, _ = torch.sort(gnn_out) - # gnn_out = sort_out - - # tr_out = pred_mini[0] - # sort_out, _ = torch.sort(tr_out) - # tr_out = sort_out - - # assert torch.equal(gnn_out, tr_out) - - - #! to remove - # scores = torch.zeros(batch_size, task.num_dst_nodes, - # device=tr_logits.device) - # tr_logits = tr_logits.detach() - # scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) - # for i in range(scores.shape[0]): - # scores[i][topk_idx[i]] = torch.sigmoid(tr_logits[i]) - # scores.scatter_(1, topk_idx, torch.sigmoid(tr_logits.detach())) - # scores[topk_index] = torch.sigmoid(tr_logits.detach().flatten()) + if (epoch <= PRETRAIN_EPOCH): + scores = torch.sigmoid(gnn_logits.detach()) + _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) + else: + _, pred_idx = torch.topk(torch.sigmoid(tr_logits.detach()), k=task.eval_k, dim=1) + pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] else: raise ValueError(f"Unsupported model type: {args.model}.") @@ -389,11 +369,11 @@ def train_and_eval_with_cfg( train_loss = train(model, optimizer, loader_dict["train"], train_sparse_tensor, epoch) optimizer.zero_grad() - val_metric = test(model, loader_dict["val"], "val") + val_metric = test(model, loader_dict["val"], "val", epoch) if val_metric > best_val_metric: best_val_metric = val_metric - best_test_metric = test(model, loader_dict["test"], "test") + best_test_metric = test(model, loader_dict["test"], "test", epoch) lr_scheduler.step() print(f"Train Loss: {train_loss:.4f}, Val: {val_metric:.4f}") diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index 201d7ba..0fd0f3c 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -28,13 +28,13 @@ class ReRankTransformer(torch.nn.Module): col_stats_dict (Dict[str, Dict[str, Dict[StatType, Any]]]): column stats num_nodes (int): number of nodes, num_layers (int): number of mp layers, - channels (int): input dimension, - embedding_dim (int): embedding dimension size, + channels (int): input dimension and embedding dimension aggr (str): aggregation type, norm (norm): normalization type, dropout (float): dropout rate for the transformer float, heads (int): number of attention heads, - rank_topk (int): how many top results of gnn would be reranked,""" + rank_topk (int): how many top results of gnn would be reranked, + num_tr_layers (int): number of transformer layers,""" def __init__( self, data: HeteroData, @@ -42,12 +42,13 @@ def __init__( num_nodes: int, num_layers: int, channels: int, - embedding_dim: int, aggr: str = 'sum', norm: str = 'layer_norm', dropout: float = 0.2, heads: int = 1, rank_topk: int = 100, + t_encoding_type: str = "absolute", + num_tr_layers: int = 1, torch_frame_model_cls: Type[torch.nn.Module] = ResNet, torch_frame_model_kwargs: Optional[Dict[str, Any]] = None, ) -> None: @@ -70,6 +71,7 @@ def __init__( if "time" in data[node_type] ], channels=channels, + encoding_type=t_encoding_type, ) self.gnn = HeteroGraphSAGE( node_types=data.node_types, @@ -84,23 +86,24 @@ def __init__( norm=norm, num_layers=1, ) - self.lhs_projector = torch.nn.Linear(channels, embedding_dim) + self.lhs_projector = torch.nn.Linear(channels, channels) self.id_awareness_emb = torch.nn.Embedding(1, channels) - self.rhs_embedding = torch.nn.Embedding(num_nodes, embedding_dim) - self.lin_offset_idgnn = torch.nn.Linear(embedding_dim, 1) - self.lin_offset_embgnn = torch.nn.Linear(embedding_dim, 1) + self.rhs_embedding = torch.nn.Embedding(num_nodes, channels) + self.lin_offset_idgnn = torch.nn.Linear(channels, 1) + self.lin_offset_embgnn = torch.nn.Linear(channels, 1) self.rank_topk = rank_topk + + self.tr_embed_size = channels * 2 self.tr_blocks = torch.nn.ModuleList([ MultiheadAttentionBlock( - channels=embedding_dim*2, + channels=self.tr_embed_size, heads=heads, layer_norm=True, dropout=dropout, - ) for _ in range(1) + ) for _ in range(num_tr_layers) ]) - # self.tr_lin = torch.nn.Linear(embedding_dim*2, embedding_dim) - self.tr_lin = torch.nn.Linear(embedding_dim*2,1) + self.tr_lin = torch.nn.Linear(self.tr_embed_size,1) self.channels = channels @@ -199,19 +202,18 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index out_indices = topk_indices.clone() # [batch_size, topk, embed_size] seq = shallow_rhs_embed[topk_indices.flatten()].view(batch_size * self.rank_topk, embed_size) - # rhs_idgnn_index = lhs_idgnn_batch * num_rhs_nodes + rhs_idgnn_index query_rhs_idgnn_index, mask = map_index(topk_indices.view(-1), rhs_idgnn_index) - id_gnn_seq = torch.zeros(batch_size * self.rank_topk, embed_size) + id_gnn_seq = torch.zeros(batch_size * self.rank_topk, embed_size).to(rhs_idgnn_embed.device) id_gnn_seq[mask] = rhs_idgnn_embed[query_rhs_idgnn_index] - logit_mask = torch.zeros(batch_size * self.rank_topk, embed_size, dtype=bool) + logit_mask = torch.zeros(batch_size * self.rank_topk, embed_size, dtype=bool).to(rhs_idgnn_embed.device) logit_mask[mask] = True seq = torch.where(logit_mask, id_gnn_seq.view(-1,embed_size), seq.view(-1,embed_size)) lhs_uniq_embed = lhs_embedding[:batch_size] - seq = seq.clone() + # seq = seq.clone() seq = seq.view(batch_size,self.rank_topk,-1) lhs_uniq_embed = lhs_uniq_embed.view(-1,1,embed_size) @@ -222,16 +224,16 @@ def rerank(self, gnn_logits, shallow_rhs_embed, rhs_idgnn_embed, rhs_idgnn_index seq = block(seq, seq) # [# nodes, topk, embed_size] #! just get the logit directly from transformer - seq = seq.view(-1,embed_size*2) + seq = seq.reshape(-1,self.tr_embed_size) tr_logits = self.tr_lin(seq) # [batch_size, embed_size] + tr_logits = tr_logits.view(batch_size,self.rank_topk) + # seq = seq.view(batch_size * self.rank_topk, embed_size) # lhs_uniq_embed = lhs_uniq_embed.reshape(batch_size * self.rank_topk, embed_size) # tr_logits = (lhs_uniq_embed.view(-1, embed_size) * seq.view(-1, embed_size)).sum( # dim=-1).flatten() - tr_logits = tr_logits.view(batch_size,self.rank_topk) - return tr_logits, out_indices From 66db12f2f2b449a61a5f386bcb428b924255b765 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Fri, 30 Aug 2024 17:35:24 +0000 Subject: [PATCH 24/26] adding test for max_map --- .../relbench_link_prediction_benchmark.py | 14 +- benchmark/test_max_map.py | 507 ++++++++++++++++++ 2 files changed, 518 insertions(+), 3 deletions(-) create mode 100644 benchmark/test_max_map.py diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 31a9a78..4b85db7 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -39,6 +39,11 @@ from torch_geometric.utils import index_to_mask from torch_geometric.utils.map import map_index +from relbench.metrics import ( + link_prediction_map, +) + + PRETRAIN_EPOCH = 1 TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] @@ -159,12 +164,12 @@ model_search_space = { "encoder_channels": [64, 128, 256], "encoder_layers": [2, 4, 8], - "channels": [64,128,256], + "channels": [64,128], "norm": ["layer_norm", "batch_norm"], "dropout": [0.0,0.1,0.2], "t_encoding_type": ["fuse", "absolute"], "rank_topk": [100,150,200], - "num_tr_layers": [1,2,3], + "num_tr_layers": [1,2], } train_search_space = { "batch_size": [256, 512], @@ -185,6 +190,7 @@ def train( loss_accum = count_accum = 0 steps = 0 total_steps = min(len(loader), args.max_steps_per_epoch) + for batch in tqdm(loader, total=total_steps, desc="Train"): batch = batch.to(device) @@ -231,10 +237,11 @@ def train( true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) true_label[mask.view(true_label.shape)] = 1.0 - # # #* empty label rows + # # # #* empty label rows # nonzero_mask = torch.any(true_label, dim=1) # tr_logits = tr_logits[nonzero_mask] # true_label = true_label[nonzero_mask] + loss = F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) loss.backward() @@ -320,6 +327,7 @@ def train_and_eval_with_cfg( table_input = get_link_train_table_input(table, task) dst_nodes_dict[split] = table_input.dst_nodes num_dst_nodes_dict[split] = table_input.num_dst_nodes + loader_dict[split] = NeighborLoader( data, num_neighbors=num_neighbors, diff --git a/benchmark/test_max_map.py b/benchmark/test_max_map.py new file mode 100644 index 0000000..52a4bed --- /dev/null +++ b/benchmark/test_max_map.py @@ -0,0 +1,507 @@ +""" +$ python relbench_link_prediction_benchmark.py --dataset rel-stack --task post-post-related --model rhstransformer --num_trials 10 +""" + +import argparse +import json +import os +import os.path as osp +import time +import warnings +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import optuna +import torch +import torch.nn.functional as F +from relbench.base import Dataset, RecommendationTask, TaskType +from relbench.datasets import get_dataset +from relbench.modeling.graph import ( + get_link_train_table_input, + make_pkey_fkey_graph, +) +from relbench.modeling.loader import SparseTensor +from relbench.modeling.utils import get_stype_proposal +from relbench.tasks import get_task +from torch import Tensor +from torch.optim.lr_scheduler import ExponentialLR +from torch_frame import stype +from torch_frame.config.text_embedder import TextEmbedderConfig +from torch_geometric.loader import NeighborLoader +from torch_geometric.seed import seed_everything +from torch_geometric.typing import NodeType +from torch_geometric.utils.cross_entropy import sparse_cross_entropy +from tqdm import tqdm + +from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer +from hybridgnn.utils import GloveTextEmbedding +from torch_geometric.utils import index_to_mask +from torch_geometric.utils.map import map_index + +from relbench.metrics import ( + link_prediction_map, +) + + +PRETRAIN_EPOCH = 0 + +TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] +LINK_PREDICTION_METRIC = "link_prediction_map" + +parser = argparse.ArgumentParser() +parser.add_argument("--dataset", type=str, default="rel-stack") +parser.add_argument("--task", type=str, default="user-post-comment") +parser.add_argument( + "--model", + type=str, + default="hybridgnn", + choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer", "rerank_transformer"], +) +parser.add_argument("--epochs", type=int, default=20) +parser.add_argument("--num_trials", type=int, default=10, + help="Number of Optuna-based hyper-parameter tuning.") +parser.add_argument( + "--num_repeats", type=int, default=5, + help="Number of repeated training and eval on the best config.") +parser.add_argument("--eval_epochs_interval", type=int, default=1) +parser.add_argument("--num_layers", type=int, default=2) +parser.add_argument("--num_neighbors", type=int, default=128) +parser.add_argument("--temporal_strategy", type=str, default="last", + choices=["last", "uniform"]) +parser.add_argument("--max_steps_per_epoch", type=int, default=2000) +parser.add_argument("--num_workers", type=int, default=0) +parser.add_argument("--seed", type=int, default=42) +parser.add_argument("--cache_dir", type=str, + default=os.path.expanduser("~/.cache/relbench_examples")) +parser.add_argument("--result_path", type=str, default="result") +args = parser.parse_args() + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if torch.cuda.is_available(): + torch.set_num_threads(1) +seed_everything(args.seed) + +if args.dataset == "rel-trial": + args.num_layers = 4 + +dataset: Dataset = get_dataset(args.dataset, download=True) +task: RecommendationTask = get_task(args.dataset, args.task, download=True) +tune_metric = LINK_PREDICTION_METRIC +assert task.task_type == TaskType.LINK_PREDICTION + +stypes_cache_path = Path(f"{args.cache_dir}/{args.dataset}/stypes.json") +try: + with open(stypes_cache_path, "r") as f: + col_to_stype_dict = json.load(f) + for table, col_to_stype in col_to_stype_dict.items(): + for col, stype_str in col_to_stype.items(): + col_to_stype[col] = stype(stype_str) +except FileNotFoundError: + col_to_stype_dict = get_stype_proposal(dataset.get_db()) + Path(stypes_cache_path).parent.mkdir(parents=True, exist_ok=True) + with open(stypes_cache_path, "w") as f: + json.dump(col_to_stype_dict, f, indent=2, default=str) + +data, col_stats_dict = make_pkey_fkey_graph( + dataset.get_db(), + col_to_stype_dict=col_to_stype_dict, + text_embedder_cfg=TextEmbedderConfig( + text_embedder=GloveTextEmbedding(device=device), batch_size=256), + cache_dir=f"{args.cache_dir}/{args.dataset}/materialized", +) + +num_neighbors = [ + int(args.num_neighbors // 2**i) for i in range(args.num_layers) +] + +model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer]] + +if args.model == "idgnn": + model_search_space = { + "encoder_channels": [64, 128, 256], + "encoder_layers": [2, 4, 8], + "channels": [64, 128, 256], + "norm": ["layer_norm", "batch_norm"] + } + train_search_space = { + "batch_size": [256, 512], + "base_lr": [0.0001, 0.01], + "gamma_rate": [0.9, 0.95, 1.], + } + model_cls = IDGNN +elif args.model in ["hybridgnn", "shallowrhsgnn"]: + model_search_space = { + "encoder_channels": [64, 128, 256], + "encoder_layers": [2, 4, 8], + "channels": [64, 128, 256], + "embedding_dim": [64, 128, 256], + "norm": ["layer_norm", "batch_norm"] + } + train_search_space = { + "batch_size": [256, 512], + "base_lr": [0.001, 0.01], + "gamma_rate": [0.8, 1.], + } + model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN) +elif args.model in ["rhstransformer"]: + model_search_space = { + "encoder_channels": [64, 128], + "encoder_layers": [2, 4], + "channels": [64, 128], + "embedding_dim": [64, 128], + "norm": ["layer_norm", "batch_norm"], + "dropout": [0.1, 0.2], + "t_encoding_type": ["fuse", "absolute"], + } + train_search_space = { + "batch_size": [128, 256], + "base_lr": [0.0005, 0.01], + "gamma_rate": [0.8, 1.0], + } + model_cls = RHSTransformer +elif args.model in ["rerank_transformer"]: + model_search_space = { + "encoder_channels": [256], + "encoder_layers": [8], + "channels": [128], + "norm": ["layer_norm", "batch_norm"], + "dropout": [0.1,0.2], + "t_encoding_type": ["fuse", "absolute"], + "rank_topk": [500], + "num_tr_layers": [1], + } + train_search_space = { + "batch_size": [512], + "base_lr": [0.001, 0.01], + "gamma_rate": [0.8,1.0], + } + model_cls = ReRankTransformer + +def train( + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + loader: NeighborLoader, + train_sparse_tensor: SparseTensor, + epoch:int, +) -> float: + model.train() + + loss_accum = count_accum = 0 + steps = 0 + total_steps = min(len(loader), args.max_steps_per_epoch) + + + pred_list = [] + label_list = [] + + + for batch in tqdm(loader, total=total_steps, desc="Train"): + batch = batch.to(device) + + # Get ground-truth + input_id = batch[task.src_entity_table].input_id + src_batch, dst_index = train_sparse_tensor[input_id] + + # Optimization + optimizer.zero_grad() + + if args.model == "idgnn": + out = model(batch, task.src_entity_table, + task.dst_entity_table).flatten() + batch_size = batch[task.src_entity_table].batch_size + + # Get target label + target = torch.isin( + batch[task.dst_entity_table].batch + + batch_size * batch[task.dst_entity_table].n_id, + src_batch + batch_size * dst_index, + ).float() + + loss = F.binary_cross_entropy_with_logits(out, target) + numel = out.numel() + elif args.model in ["hybridgnn", "shallowrhsgnn"]: + logits = model(batch, task.src_entity_table, task.dst_entity_table) + edge_label_index = torch.stack([src_batch, dst_index], dim=0) + loss = sparse_cross_entropy(logits, edge_label_index) + numel = len(batch[task.dst_entity_table].batch) + elif args.model in ["rhstransformer"]: + logits = model(batch, task.src_entity_table, task.dst_entity_table) + edge_label_index = torch.stack([src_batch, dst_index], dim=0) + loss = sparse_cross_entropy(logits, edge_label_index) + numel = len(batch[task.dst_entity_table].batch) + elif args.model in ["rerank_transformer"]: + numel = len(batch[task.dst_entity_table].batch) + gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) + if (epoch <= PRETRAIN_EPOCH): + edge_label_index = torch.stack([src_batch, dst_index], dim=0) + loss = sparse_cross_entropy(gnn_logits, edge_label_index) + else: + #* approach with map_index + label_index, mask = map_index(topk_idx.view(-1), dst_index) + true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) + true_label[mask.view(true_label.shape)] = 1.0 + + + # scores = torch.sigmoid(gnn_logits.detach()) + # _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) + # pred_list.append(pred_mini) + + + #! to remove + edge_label_index = torch.stack([src_batch, dst_index], dim=0) + loss = sparse_cross_entropy(gnn_logits, edge_label_index) + # # loss = F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) + + #! test for upper bound of transformer + nonzero_mask = torch.any(true_label, dim=1) + tr_logits[nonzero_mask] = true_label[nonzero_mask] + + _, pred_idx = torch.topk(tr_logits, k=task.eval_k, dim=1) + pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] + pred_list.append(pred_mini) + + + # # # #* empty label rows + # nonzero_mask = torch.any(true_label, dim=1) + # tr_logits = tr_logits[nonzero_mask] + # true_label = true_label[nonzero_mask] + + loss.backward() + + optimizer.step() + + loss_accum += float(loss) * numel + count_accum += numel + + steps += 1 + if steps > args.max_steps_per_epoch: + break + + pred = torch.cat(pred_list, dim=0).cpu().numpy() + res = task.evaluate(pred, task.get_table('train')) + # label = torch.cat(label_list, dim=0).cpu().numpy() + + # train_map = link_prediction_map(pred, label) + print ("train MAP is", res) + if count_accum == 0: + warnings.warn(f"Did not sample a single '{task.dst_entity_table}' " + f"node in any mini-batch. Try to increase the number " + f"of layers/hops and re-try. If you run into memory " + f"issues with deeper nets, decrease the batch size.") + + return loss_accum / count_accum if count_accum > 0 else float("nan") + + +@torch.no_grad() +def test(model: torch.nn.Module, loader: NeighborLoader, stage: str, epoch:int,) -> float: + model.eval() + + pred_list: List[Tensor] = [] + for batch in tqdm(loader, desc=stage): + batch = batch.to(device) + batch_size = batch[task.src_entity_table].batch_size + + if args.model == "idgnn": + out = (model.forward(batch, task.src_entity_table, + task.dst_entity_table).detach().flatten()) + scores = torch.zeros(batch_size, task.num_dst_nodes, + device=out.device) + scores[batch[task.dst_entity_table].batch, + batch[task.dst_entity_table].n_id] = torch.sigmoid(out) + _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) + elif args.model in ["hybridgnn", "shallowrhsgnn"]: + # Get ground-truth + out = model(batch, task.src_entity_table, + task.dst_entity_table).detach() + scores = torch.sigmoid(out) + _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) + elif args.model in ["rhstransformer"]: + out = model(batch, task.src_entity_table, + task.dst_entity_table).detach() + scores = torch.sigmoid(out) + _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) + elif args.model in ["rerank_transformer"]: + gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, + task.dst_entity_table, + task.dst_entity_col) + + if (epoch <= PRETRAIN_EPOCH): + scores = torch.sigmoid(gnn_logits.detach()) + _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) + else: + _, pred_idx = torch.topk(torch.sigmoid(tr_logits.detach()), k=min(task.eval_k, tr_logits.shape[1]), dim=1) + pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] + + else: + raise ValueError(f"Unsupported model type: {args.model}.") + + pred_list.append(pred_mini) + + pred = torch.cat(pred_list, dim=0).cpu().numpy() + res = task.evaluate(pred, task.get_table(stage)) + return res[LINK_PREDICTION_METRIC] + + +def train_and_eval_with_cfg( + model_cfg: Dict[str, Any], + train_cfg: Dict[str, Any], + trial: Optional[optuna.trial.Trial] = None, +) -> Tuple[float, float]: + loader_dict: Dict[str, NeighborLoader] = {} + dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {} + num_dst_nodes_dict: Dict[str, int] = {} + for split in ["train", "val", "test"]: + table = task.get_table(split) + table_input = get_link_train_table_input(table, task) + dst_nodes_dict[split] = table_input.dst_nodes + num_dst_nodes_dict[split] = table_input.num_dst_nodes + + #! not shuffle for train + # split == "train" + loader_dict[split] = NeighborLoader( + data, + num_neighbors=num_neighbors, + time_attr="time", + input_nodes=table_input.src_nodes, + input_time=table_input.src_time, + subgraph_type="bidirectional", + batch_size=train_cfg["batch_size"], + temporal_strategy=args.temporal_strategy, + shuffle=False, + num_workers=args.num_workers, + persistent_workers=args.num_workers > 0, + ) + + if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer", "rerank_transformer"]: + model_cfg["num_nodes"] = num_dst_nodes_dict["train"] + elif args.model == "idgnn": + model_cfg["out_channels"] = 1 + encoder_model_kwargs = { + "channels": model_cfg["encoder_channels"], + "num_layers": model_cfg["encoder_layers"], + } + model_kwargs = { + k: v + for k, v in model_cfg.items() + if k not in ["encoder_channels", "encoder_layers"] + } + # Use model_cfg to set up training procedure + model = model_cls( + **model_kwargs, + data=data, + col_stats_dict=col_stats_dict, + num_layers=args.num_layers, + torch_frame_model_kwargs=encoder_model_kwargs, + ).to(device) + model.reset_parameters() + # Use train_cfg to set up training procedure + optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg["base_lr"]) + lr_scheduler = ExponentialLR(optimizer, gamma=train_cfg["gamma_rate"]) + + best_val_metric: float = 0.0 + best_test_metric: float = 0.0 + + for epoch in range(1, args.epochs + 1): + train_sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], + device=device) + train_loss = train(model, optimizer, loader_dict["train"], + train_sparse_tensor, epoch) + optimizer.zero_grad() + val_metric = test(model, loader_dict["val"], "val", epoch) + + if val_metric > best_val_metric: + best_val_metric = val_metric + best_test_metric = test(model, loader_dict["test"], "test", epoch) + + lr_scheduler.step() + print(f"Train Loss: {train_loss:.4f}, Val: {val_metric:.4f}") + + if trial is not None: + trial.report(val_metric, epoch) + if trial.should_prune(): + raise optuna.TrialPruned() + + print( + f"Best val: {best_val_metric:.4f}, Best test: {best_test_metric:.4f}") + return best_val_metric, best_test_metric + + +def objective(trial: optuna.trial.Trial) -> float: + model_cfg: Dict[str, Any] = {} + for name, search_list in model_search_space.items(): + assert isinstance(search_list, list) + model_cfg[name] = trial.suggest_categorical(name, search_list) + train_cfg: Dict[str, Any] = {} + for name, search_list in train_search_space.items(): + assert isinstance(search_list, list) + if name == "batch_size": + train_cfg[name] = trial.suggest_categorical(name, search_list) + else: + train_cfg[name] = trial.suggest_loguniform(name, search_list[0], + search_list[1]) + + best_val_metric, _ = train_and_eval_with_cfg(model_cfg=model_cfg, + train_cfg=train_cfg, + trial=trial) + return best_val_metric + + +def main_gnn() -> None: + # Hyper-parameter optimization with Optuna + print("Hyper-parameter search via Optuna") + start_time = time.time() + study = optuna.create_study( + pruner=optuna.pruners.MedianPruner(), + direction="maximize", + ) + study.optimize(objective, n_trials=args.num_trials) + end_time = time.time() + search_time = end_time - start_time + print("Hyper-parameter search done. Found the best config.") + params = study.best_params + best_train_cfg = {} + for train_cfg_key in TRAIN_CONFIG_KEYS: + best_train_cfg[train_cfg_key] = params.pop(train_cfg_key) + best_model_cfg = params + + print(f"Repeat experiments {args.num_repeats} times with the best train " + f"config {best_train_cfg} and model config {best_model_cfg}.") + start_time = time.time() + best_val_metrics = [] + best_test_metrics = [] + for _ in range(args.num_repeats): + best_val_metric, best_test_metric = train_and_eval_with_cfg( + best_model_cfg, best_train_cfg) + best_val_metrics.append(best_val_metric) + best_test_metrics.append(best_test_metric) + end_time = time.time() + final_model_time = (end_time - start_time) / args.num_repeats + best_val_metrics_array = np.array(best_val_metrics) + best_test_metrics_array = np.array(best_test_metrics) + + result_dict = { + "args": args.__dict__, + "best_val_metrics": best_val_metrics_array, + "best_test_metrics": best_test_metrics_array, + "best_val_metric": best_val_metrics_array.mean(), + "best_test_metric": best_test_metrics_array.mean(), + "best_train_cfg": best_train_cfg, + "best_model_cfg": best_model_cfg, + "search_time": search_time, + "final_model_time": final_model_time, + "total_time": search_time + final_model_time, + } + print(result_dict) + # Save results + if args.result_path != "": + os.makedirs(args.result_path, exist_ok=True) + torch.save( + result_dict, + osp.join(args.result_path, + f"{args.dataset}_{args.task}_{args.model}")) + + +if __name__ == "__main__": + print(args) + main_gnn() From dfd559f4ad966318cbcc1f9b3ca65807df65f5ac Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Fri, 30 Aug 2024 18:08:18 +0000 Subject: [PATCH 25/26] adding how we are testing the rerank upper bound --- benchmark/test_max_map.py | 35 +++++++++++------------ hybridgnn/nn/models/rerank_transformer.py | 5 ++++ 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/benchmark/test_max_map.py b/benchmark/test_max_map.py index 52a4bed..996d504 100644 --- a/benchmark/test_max_map.py +++ b/benchmark/test_max_map.py @@ -45,6 +45,7 @@ PRETRAIN_EPOCH = 0 +use_gnn_pred = False TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] LINK_PREDICTION_METRIC = "link_prediction_map" @@ -168,7 +169,7 @@ "norm": ["layer_norm", "batch_norm"], "dropout": [0.1,0.2], "t_encoding_type": ["fuse", "absolute"], - "rank_topk": [500], + "rank_topk": [100], "num_tr_layers": [1], } train_search_space = { @@ -242,31 +243,27 @@ def train( true_label = torch.zeros(topk_idx.shape).to(tr_logits.device) true_label[mask.view(true_label.shape)] = 1.0 + edge_label_index = torch.stack([src_batch, dst_index], dim=0) + loss = sparse_cross_entropy(gnn_logits, edge_label_index) - # scores = torch.sigmoid(gnn_logits.detach()) - # _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) - # pred_list.append(pred_mini) + if (use_gnn_pred): + scores = torch.sigmoid(gnn_logits.detach()) + _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) + pred_list.append(pred_mini) + else: + #! test for upper bound of transformer + nonzero_mask = torch.any(true_label, dim=1) + tr_logits[nonzero_mask] = true_label[nonzero_mask] + _, pred_idx = torch.topk(torch.sigmoid(tr_logits.detach()), k=task.eval_k, dim=1) + pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] + pred_list.append(pred_mini) - #! to remove - edge_label_index = torch.stack([src_batch, dst_index], dim=0) - loss = sparse_cross_entropy(gnn_logits, edge_label_index) # # loss = F.binary_cross_entropy_with_logits(tr_logits, true_label.float()) - #! test for upper bound of transformer - nonzero_mask = torch.any(true_label, dim=1) - tr_logits[nonzero_mask] = true_label[nonzero_mask] - - _, pred_idx = torch.topk(tr_logits, k=task.eval_k, dim=1) - pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] - pred_list.append(pred_mini) + - # # # #* empty label rows - # nonzero_mask = torch.any(true_label, dim=1) - # tr_logits = tr_logits[nonzero_mask] - # true_label = true_label[nonzero_mask] - loss.backward() optimizer.step() diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index 0fd0f3c..5203488 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -185,7 +185,12 @@ def forward( embgnn_logits[lhs_idgnn_batch, rhs_idgnn_index] = idgnn_logits shallow_rhs_embed = rhs_embedding.weight + + _, original_indices = torch.topk(embgnn_logits, self.rank_topk, dim=1) + transformer_logits, topk_index = self.rerank(embgnn_logits, shallow_rhs_embed, rhs_gnn_embedding, rhs_idgnn_index, idgnn_logits, lhs_idgnn_batch,lhs_embedding_projected[lhs_idgnn_batch]) + + assert torch.equal(original_indices, topk_index) return embgnn_logits, transformer_logits, topk_index From 342b90a1a4077583b644ca146302cbf23d46cd74 Mon Sep 17 00:00:00 2001 From: shenyang huang Date: Fri, 30 Aug 2024 21:03:52 +0000 Subject: [PATCH 26/26] commit --- .../relbench_link_prediction_benchmark.py | 20 +++++++++++++++---- benchmark/test_max_map.py | 11 +++------- hybridgnn/nn/models/rerank_transformer.py | 7 ++++--- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py index 4b85db7..6248f12 100644 --- a/benchmark/relbench_link_prediction_benchmark.py +++ b/benchmark/relbench_link_prediction_benchmark.py @@ -44,7 +44,7 @@ ) -PRETRAIN_EPOCH = 1 +PRETRAIN_EPOCH = 6 TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"] LINK_PREDICTION_METRIC = "link_prediction_map" @@ -173,7 +173,7 @@ } train_search_space = { "batch_size": [256, 512], - "base_lr": [0.0005, 0.01], + "base_lr": [0.001,0.002], "gamma_rate": [0.8,1.0], } model_cls = ReRankTransformer @@ -364,9 +364,21 @@ def train_and_eval_with_cfg( torch_frame_model_kwargs=encoder_model_kwargs, ).to(device) model.reset_parameters() + + + # Use train_cfg to set up training procedure optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg["base_lr"]) - lr_scheduler = ExponentialLR(optimizer, gamma=train_cfg["gamma_rate"]) + # lr_scheduler = ExponentialLR(optimizer, gamma=train_cfg["gamma_rate"]) + + # if (args.model == "rerank_transformer"): + # tr_layers, tr_lin = model.get_transformer() + # tr_optimizer + + + + # tr_optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg["base_lr"]) + best_val_metric: float = 0.0 best_test_metric: float = 0.0 @@ -383,7 +395,7 @@ def train_and_eval_with_cfg( best_val_metric = val_metric best_test_metric = test(model, loader_dict["test"], "test", epoch) - lr_scheduler.step() + # lr_scheduler.step() print(f"Train Loss: {train_loss:.4f}, Val: {val_metric:.4f}") if trial is not None: diff --git a/benchmark/test_max_map.py b/benchmark/test_max_map.py index 996d504..f871a77 100644 --- a/benchmark/test_max_map.py +++ b/benchmark/test_max_map.py @@ -322,14 +322,9 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str, epoch:int, gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col) - - if (epoch <= PRETRAIN_EPOCH): - scores = torch.sigmoid(gnn_logits.detach()) - _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) - else: - _, pred_idx = torch.topk(torch.sigmoid(tr_logits.detach()), k=min(task.eval_k, tr_logits.shape[1]), dim=1) - pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx] - + scores = torch.sigmoid(gnn_logits.detach()) + _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1) + else: raise ValueError(f"Unsupported model type: {args.model}.") diff --git a/hybridgnn/nn/models/rerank_transformer.py b/hybridgnn/nn/models/rerank_transformer.py index 5203488..ee51023 100644 --- a/hybridgnn/nn/models/rerank_transformer.py +++ b/hybridgnn/nn/models/rerank_transformer.py @@ -109,6 +109,10 @@ def __init__( self.reset_parameters() + # def get_transformer(self) -> list: + # return [self.tr_blocks, self.tr_lin] + + def reset_parameters(self) -> None: self.encoder.reset_parameters() self.temporal_encoder.reset_parameters() @@ -186,11 +190,8 @@ def forward( shallow_rhs_embed = rhs_embedding.weight - _, original_indices = torch.topk(embgnn_logits, self.rank_topk, dim=1) - transformer_logits, topk_index = self.rerank(embgnn_logits, shallow_rhs_embed, rhs_gnn_embedding, rhs_idgnn_index, idgnn_logits, lhs_idgnn_batch,lhs_embedding_projected[lhs_idgnn_batch]) - assert torch.equal(original_indices, topk_index) return embgnn_logits, transformer_logits, topk_index