Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ReRank Transformer #22

Draft
wants to merge 27 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
829fee5
adding initial implementation of rhs transformer to repo
andyhuang-kumo Aug 8, 2024
12a83bd
adding rhs transformer to the benchmark script
andyhuang-kumo Aug 9, 2024
a3a9779
rhs transformer upload
andyhuang-kumo Aug 12, 2024
988781b
updating tr
andyhuang-kumo Aug 12, 2024
e5faa22
running code
andyhuang-kumo Aug 12, 2024
5e904ea
running code
andyhuang-kumo Aug 12, 2024
d6ff169
adding transformer changes
andyhuang-kumo Aug 13, 2024
4884b83
permute the index, the rhs and then reverse it
andyhuang-kumo Aug 13, 2024
08041cb
removing none to replace with None
andyhuang-kumo Aug 13, 2024
29c70ad
add time fuse encoder to extract time pe
andyhuang-kumo Aug 14, 2024
f58b528
update hyperparameter options
andyhuang-kumo Aug 14, 2024
8f4966c
adding rerank_transformer
andyhuang-kumo Aug 19, 2024
90a8817
setting zeros for not used logits in rerank transformer
andyhuang-kumo Aug 19, 2024
1eb769c
adding reranker transformer
andyhuang-kumo Aug 21, 2024
ea58dd0
Merge branch 'master' into rhs_tr
andyhuang-kumo Aug 21, 2024
75445ce
updating RHS transformer code
andyhuang-kumo Aug 21, 2024
7380a17
updating rerank_transformer
andyhuang-kumo Aug 23, 2024
7622b94
push current version
andyhuang-kumo Aug 23, 2024
dc73f7c
converging to the follow implementation
andyhuang-kumo Aug 23, 2024
38f9cf4
training both from epoch 0, potentially try better ways
andyhuang-kumo Aug 23, 2024
8e846dd
for transformer, not training with nodes whose prediction is not within
andyhuang-kumo Aug 25, 2024
41a38cf
commiting before clean up
andyhuang-kumo Aug 26, 2024
a181f45
semi working version of rerank transformer, still needs work for
andyhuang-kumo Aug 27, 2024
b86641f
somewhat stable version of the idea
andyhuang-kumo Aug 28, 2024
66db12f
adding test for max_map
andyhuang-kumo Aug 30, 2024
dfd559f
adding how we are testing the rerank upper bound
andyhuang-kumo Aug 30, 2024
342b90a
commit
andyhuang-kumo Aug 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ coverage.xml
venv/*
*.out
data/**
*.txt
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,17 @@
Run [`benchmark/relbench_link_prediction_benchmark.py`](https://github.com/kumo-ai/hybridgnn/blob/master/benchmark/relbench_link_prediction_benchmark.py)

```sh
python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
python relbench_link_prediction_benchmark.py --dataset rel-stack --task user-post-comment --model rhstransformer --num_trials 10
python relbench_link_prediction_benchmark.py --dataset rel-hm --task user-item-purcahse --model rhstransformer
python relbench_link_prediction_benchmark.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4
```


Run [`examples/relbench_example.py`](https://github.com/kumo-ai/hybridgnn/blob/master/examples/relbench_example.py)

```sh
python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn
python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn
python relbench_example.py --dataset rel-trial --task site-sponsor-run --model hybridgnn --num_layers 4
python relbench_example.py --dataset rel-trial --task condition-sponsor-run --model hybridgnn --num_layers 4
```


Expand All @@ -31,4 +33,6 @@ pip install -e .

# to run examples and benchmarks
pip install -e '.[full]'

pip install -U sentence-transformers
```
126 changes: 114 additions & 12 deletions benchmark/relbench_link_prediction_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
$ python relbench_link_prediction_benchmark.py --dataset rel-stack --task post-post-related --model rhstransformer --num_trials 10
"""

import argparse
import json
import os
Expand Down Expand Up @@ -30,8 +34,17 @@
from torch_geometric.utils.cross_entropy import sparse_cross_entropy
from tqdm import tqdm

from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN
from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer
from hybridgnn.utils import GloveTextEmbedding
from torch_geometric.utils import index_to_mask
from torch_geometric.utils.map import map_index

from relbench.metrics import (
link_prediction_map,
)


PRETRAIN_EPOCH = 6

TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
LINK_PREDICTION_METRIC = "link_prediction_map"
Expand All @@ -43,7 +56,7 @@
"--model",
type=str,
default="hybridgnn",
choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
choices=["hybridgnn", "idgnn", "shallowrhsgnn", "rhstransformer", "rerank_transformer"],
)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--num_trials", type=int, default=10,
Expand Down Expand Up @@ -102,7 +115,7 @@
int(args.num_neighbors // 2**i) for i in range(args.num_layers)
]

model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN]]
model_cls: Type[Union[IDGNN, HybridGNN, ShallowRHSGNN, RHSTransformer, ReRankTransformer]]

if args.model == "idgnn":
model_search_space = {
Expand Down Expand Up @@ -131,19 +144,53 @@
"gamma_rate": [0.8, 1.],
}
model_cls = (HybridGNN if args.model == "hybridgnn" else ShallowRHSGNN)

elif args.model in ["rhstransformer"]:
model_search_space = {
"encoder_channels": [64, 128],
"encoder_layers": [2, 4],
"channels": [64, 128],
"embedding_dim": [64, 128],
"norm": ["layer_norm", "batch_norm"],
"dropout": [0.1, 0.2],
"t_encoding_type": ["fuse", "absolute"],
}
train_search_space = {
"batch_size": [128, 256],
"base_lr": [0.0005, 0.01],
"gamma_rate": [0.8, 1.0],
}
model_cls = RHSTransformer
elif args.model in ["rerank_transformer"]:
model_search_space = {
"encoder_channels": [64, 128, 256],
"encoder_layers": [2, 4, 8],
"channels": [64,128],
"norm": ["layer_norm", "batch_norm"],
"dropout": [0.0,0.1,0.2],
"t_encoding_type": ["fuse", "absolute"],
"rank_topk": [100,150,200],
"num_tr_layers": [1,2],
}
train_search_space = {
"batch_size": [256, 512],
"base_lr": [0.001,0.002],
"gamma_rate": [0.8,1.0],
}
model_cls = ReRankTransformer

def train(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
loader: NeighborLoader,
train_sparse_tensor: SparseTensor,
epoch:int,
) -> float:
model.train()

loss_accum = count_accum = 0
steps = 0
total_steps = min(len(loader), args.max_steps_per_epoch)

for batch in tqdm(loader, total=total_steps, desc="Train"):
batch = batch.to(device)

Expand Down Expand Up @@ -173,6 +220,30 @@ def train(
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)
numel = len(batch[task.dst_entity_table].batch)
elif args.model in ["rhstransformer"]:
logits = model(batch, task.src_entity_table, task.dst_entity_table)
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(logits, edge_label_index)
numel = len(batch[task.dst_entity_table].batch)
elif args.model in ["rerank_transformer"]:
numel = len(batch[task.dst_entity_table].batch)
gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table, task.dst_entity_table, task.dst_entity_col)
if (epoch <= PRETRAIN_EPOCH):
edge_label_index = torch.stack([src_batch, dst_index], dim=0)
loss = sparse_cross_entropy(gnn_logits, edge_label_index)
else:
#* approach with map_index
label_index, mask = map_index(topk_idx.view(-1), dst_index)
true_label = torch.zeros(topk_idx.shape).to(tr_logits.device)
true_label[mask.view(true_label.shape)] = 1.0

# # # #* empty label rows
# nonzero_mask = torch.any(true_label, dim=1)
# tr_logits = tr_logits[nonzero_mask]
# true_label = true_label[nonzero_mask]

loss = F.binary_cross_entropy_with_logits(tr_logits, true_label.float())

loss.backward()

optimizer.step()
Expand All @@ -194,7 +265,7 @@ def train(


@torch.no_grad()
def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
def test(model: torch.nn.Module, loader: NeighborLoader, stage: str, epoch:int,) -> float:
model.eval()

pred_list: List[Tensor] = []
Expand All @@ -209,15 +280,33 @@ def test(model: torch.nn.Module, loader: NeighborLoader, stage: str) -> float:
device=out.device)
scores[batch[task.dst_entity_table].batch,
batch[task.dst_entity_table].n_id] = torch.sigmoid(out)
_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
elif args.model in ["hybridgnn", "shallowrhsgnn"]:
# Get ground-truth
out = model(batch, task.src_entity_table,
task.dst_entity_table).detach()
scores = torch.sigmoid(out)
_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
elif args.model in ["rhstransformer"]:
out = model(batch, task.src_entity_table,
task.dst_entity_table).detach()
scores = torch.sigmoid(out)
_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
elif args.model in ["rerank_transformer"]:
gnn_logits, tr_logits, topk_idx = model(batch, task.src_entity_table,
task.dst_entity_table,
task.dst_entity_col)

if (epoch <= PRETRAIN_EPOCH):
scores = torch.sigmoid(gnn_logits.detach())
_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
else:
_, pred_idx = torch.topk(torch.sigmoid(tr_logits.detach()), k=task.eval_k, dim=1)
pred_mini = topk_idx[torch.arange(topk_idx.size(0)).unsqueeze(1), pred_idx]

else:
raise ValueError(f"Unsupported model type: {args.model}.")

_, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
pred_list.append(pred_mini)

pred = torch.cat(pred_list, dim=0).cpu().numpy()
Expand All @@ -238,6 +327,7 @@ def train_and_eval_with_cfg(
table_input = get_link_train_table_input(table, task)
dst_nodes_dict[split] = table_input.dst_nodes
num_dst_nodes_dict[split] = table_input.num_dst_nodes

loader_dict[split] = NeighborLoader(
data,
num_neighbors=num_neighbors,
Expand All @@ -252,7 +342,7 @@ def train_and_eval_with_cfg(
persistent_workers=args.num_workers > 0,
)

if args.model in ["hybridgnn", "shallowrhsgnn"]:
if args.model in ["hybridgnn", "shallowrhsgnn", "rhstransformer", "rerank_transformer"]:
model_cfg["num_nodes"] = num_dst_nodes_dict["train"]
elif args.model == "idgnn":
model_cfg["out_channels"] = 1
Expand All @@ -274,9 +364,21 @@ def train_and_eval_with_cfg(
torch_frame_model_kwargs=encoder_model_kwargs,
).to(device)
model.reset_parameters()



# Use train_cfg to set up training procedure
optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg["base_lr"])
lr_scheduler = ExponentialLR(optimizer, gamma=train_cfg["gamma_rate"])
# lr_scheduler = ExponentialLR(optimizer, gamma=train_cfg["gamma_rate"])

# if (args.model == "rerank_transformer"):
# tr_layers, tr_lin = model.get_transformer()
# tr_optimizer



# tr_optimizer = torch.optim.Adam(model.parameters(), lr=train_cfg["base_lr"])


best_val_metric: float = 0.0
best_test_metric: float = 0.0
Expand All @@ -285,15 +387,15 @@ def train_and_eval_with_cfg(
train_sparse_tensor = SparseTensor(dst_nodes_dict["train"][1],
device=device)
train_loss = train(model, optimizer, loader_dict["train"],
train_sparse_tensor)
train_sparse_tensor, epoch)
optimizer.zero_grad()
val_metric = test(model, loader_dict["val"], "val")
val_metric = test(model, loader_dict["val"], "val", epoch)

if val_metric > best_val_metric:
best_val_metric = val_metric
best_test_metric = test(model, loader_dict["test"], "test")
best_test_metric = test(model, loader_dict["test"], "test", epoch)

lr_scheduler.step()
# lr_scheduler.step()
print(f"Train Loss: {train_loss:.4f}, Val: {val_metric:.4f}")

if trial is not None:
Expand Down
Loading