diff --git a/examples/baseline/__init__.py b/examples/baseline/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/baseline/model.py b/examples/baseline/model.py
new file mode 100644
index 0000000..fc5755f
--- /dev/null
+++ b/examples/baseline/model.py
@@ -0,0 +1,152 @@
+from typing import Any, Dict, List
+
+import torch
+from torch import Tensor
+from torch.nn import Embedding, ModuleDict
+from torch_frame.data.stats import StatType
+from torch_geometric.data import HeteroData
+from torch_geometric.nn import MLP
+from torch_geometric.typing import NodeType
+
+from .nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder
+
+
+class Model(torch.nn.Module):
+    def __init__(
+        self,
+        data: HeteroData,
+        col_stats_dict: Dict[str, Dict[str, Dict[StatType, Any]]],
+        num_layers: int,
+        channels: int,
+        out_channels: int,
+        num_src_nodes: int,
+        num_dst_nodes: int,
+        src_entity_table: str,
+        dst_entity_table: str,
+        aggr: str,
+        norm: str,
+        # List of node types to add shallow embeddings to input
+        shallow_list: List[NodeType] = [],
+        # ID awareness
+        id_awareness: bool = False,
+    ):
+        super().__init__()
+
+        self.src_entity_table = src_entity_table
+        self.dst_entity_table = dst_entity_table
+        self.encoder = HeteroEncoder(
+            channels=channels,
+            src_entity_table=src_entity_table,
+            dst_entity_table=dst_entity_table,
+            num_src_nodes=num_src_nodes,
+            num_dst_nodes=num_dst_nodes,
+            node_to_col_names_dict={
+                node_type: data[node_type].tf.col_names_dict
+                for node_type in data.node_types
+            },
+            node_to_col_stats=col_stats_dict,
+        )
+        self.temporal_encoder = HeteroTemporalEncoder(
+            node_types=[
+                node_type for node_type in data.node_types
+                if "time" in data[node_type]
+            ],
+            channels=channels,
+        )
+        self.gnn = HeteroGraphSAGE(
+            node_types=data.node_types,
+            edge_types=data.edge_types,
+            channels=channels,
+            aggr=aggr,
+            num_layers=num_layers,
+        )
+        self.head = MLP(
+            channels,
+            out_channels=out_channels,
+            norm=norm,
+            num_layers=1,
+        )
+        self.embedding_dict = ModuleDict({
+            node:
+            Embedding(data.num_nodes_dict[node], channels)
+            for node in shallow_list
+        })
+
+        self.id_awareness_emb = None
+        if id_awareness:
+            self.id_awareness_emb = torch.nn.Embedding(1, channels)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.encoder.reset_parameters()
+        self.temporal_encoder.reset_parameters()
+        self.gnn.reset_parameters()
+        self.head.reset_parameters()
+        for embedding in self.embedding_dict.values():
+            torch.nn.init.normal_(embedding.weight, std=0.1)
+        if self.id_awareness_emb is not None:
+            self.id_awareness_emb.reset_parameters()
+
+    def forward(
+        self,
+        batch: HeteroData,
+        entity_table: NodeType,
+    ) -> Tensor:
+        seed_time = batch[entity_table].seed_time
+        #######################################################################
+        x_dict = self.encoder(batch.tf_dict, batch)
+        #######################################################################
+
+        rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
+                                              batch.batch_dict)
+
+        for node_type, rel_time in rel_time_dict.items():
+            x_dict[node_type] = x_dict[node_type] + rel_time
+
+        for node_type, embedding in self.embedding_dict.items():
+            ###################################################################
+            if node_type not in {self.src_entity_table, self.dst_entity_table}:
+                x_dict[node_type] = x_dict[node_type] + embedding(
+                    batch[node_type].n_id)
+            ###################################################################
+
+        x_dict = self.gnn(
+            x_dict,
+            batch.edge_index_dict,
+            batch.num_sampled_nodes_dict,
+            batch.num_sampled_edges_dict,
+        )
+
+        return self.head(x_dict[entity_table][:seed_time.size(0)])
+
+    def forward_dst_readout(
+        self,
+        batch: HeteroData,
+        entity_table: NodeType,
+        dst_table: NodeType,
+    ) -> Tensor:
+        if self.id_awareness_emb is None:
+            raise RuntimeError(
+                "id_awareness must be set to True to use forward_dst_readout")
+        seed_time = batch[entity_table].seed_time
+        # The encoder needs the batch to look up `n_id` for the src/dst
+        # entity tables, which are embedded by ID rather than by features.
+        x_dict = self.encoder(batch.tf_dict, batch)
+        # Add ID-awareness to the root node
+        x_dict[entity_table][:seed_time.size(0)] += (
+            self.id_awareness_emb.weight)
+
+        rel_time_dict = self.temporal_encoder(seed_time, batch.time_dict,
+                                              batch.batch_dict)
+
+        for node_type, rel_time in rel_time_dict.items():
+            x_dict[node_type] = x_dict[node_type] + rel_time
+
+        for node_type, embedding in self.embedding_dict.items():
+            x_dict[node_type] = x_dict[node_type] + embedding(
+                batch[node_type].n_id)
+
+        x_dict = self.gnn(
+            x_dict,
+            batch.edge_index_dict,
+        )
+
+        return self.head(x_dict[dst_table])
diff --git a/examples/baseline/nn.py b/examples/baseline/nn.py
new file mode 100644
index 0000000..16e9144
--- /dev/null
+++ b/examples/baseline/nn.py
@@ -0,0 +1,228 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch_frame
+from torch import Tensor
+from torch_frame.data.stats import StatType
+from torch_frame.nn.models import ResNet
+from torch_geometric.data import HeteroData
+from torch_geometric.nn import (
+    HeteroConv,
+    LayerNorm,
+    PositionalEncoding,
+    SAGEConv,
+)
+from torch_geometric.typing import EdgeType, NodeType
+
+
+class HeteroEncoder(torch.nn.Module):
+    r"""HeteroEncoder based on PyTorch Frame.
+
+    Args:
+        channels (int): The output channels for each node type.
+        src_entity_table (str): Source entity table name.
+        dst_entity_table (str): Destination entity table name.
+        num_src_nodes (int): Number of source nodes.
+        num_dst_nodes (int): Number of destination nodes.
+        node_to_col_names_dict
+            (Dict[NodeType, Dict[torch_frame.stype, List[str]]]):
+            A dictionary mapping from node type to a column names dictionary
+            compatible with PyTorch Frame.
+        node_to_col_stats (Dict[NodeType, Dict[str, Dict[StatType, Any]]]):
+            A dictionary mapping from node type to column statistics.
+        torch_frame_model_cls: Model class for PyTorch Frame. The class object
+            takes a :class:`TensorFrame` object as input and outputs
+            :obj:`channels`-dimensional embeddings. Defaults to
+            :class:`torch_frame.nn.ResNet`.
+        torch_frame_model_kwargs (Dict[str, Any]): Keyword arguments for the
+            :class:`torch_frame_model_cls` class. The defaults are specific to
+            :class:`torch_frame.nn.ResNet` and should be changed for a
+            different :class:`torch_frame_model_cls`.
+        default_stype_encoder_cls_kwargs (Dict[torch_frame.stype, Any]):
+            A dictionary mapping from a :obj:`torch_frame.stype` object to a
+            tuple specifying a :class:`torch_frame.nn.StypeEncoder` class and
+            its keyword arguments :obj:`kwargs`.
+ """ + def __init__( + self, + channels: int, + ####################################################################### + src_entity_table: str, + dst_entity_table: str, + num_src_nodes: int, + num_dst_nodes: int, + ####################################################################### + node_to_col_names_dict: Dict[NodeType, Dict[torch_frame.stype, + List[str]]], + node_to_col_stats: Dict[NodeType, Dict[str, Dict[StatType, Any]]], + torch_frame_model_cls=ResNet, + torch_frame_model_kwargs: Dict[str, Any] = { + "channels": 128, + "num_layers": 4, + }, + default_stype_encoder_cls_kwargs: Dict[torch_frame.stype, Any] = { + torch_frame.categorical: (torch_frame.nn.EmbeddingEncoder, {}), + torch_frame.numerical: (torch_frame.nn.LinearEncoder, {}), + torch_frame.multicategorical: ( + torch_frame.nn.MultiCategoricalEmbeddingEncoder, + {}, + ), + torch_frame.embedding: (torch_frame.nn.LinearEmbeddingEncoder, {}), + torch_frame.timestamp: (torch_frame.nn.TimestampEncoder, {}), + }, + ): + super().__init__() + + ####################################################################### + self.src_entity_table = src_entity_table + self.dst_entity_table = dst_entity_table + ####################################################################### + + self.encoders = torch.nn.ModuleDict() + + for node_type in node_to_col_names_dict.keys(): + ################################################################### + if node_type == self.src_entity_table: + self.encoders[node_type] = nn.Embedding( + num_src_nodes, channels) + elif node_type == self.dst_entity_table: + self.encoders[node_type] = nn.Embedding( + num_dst_nodes, channels) + ################################################################### + else: + stype_encoder_dict = { + stype: + default_stype_encoder_cls_kwargs[stype][0]( + **default_stype_encoder_cls_kwargs[stype][1]) + for stype in node_to_col_names_dict[node_type].keys() + } + torch_frame_model = torch_frame_model_cls( + **torch_frame_model_kwargs, + out_channels=channels, + col_stats=node_to_col_stats[node_type], + col_names_dict=node_to_col_names_dict[node_type], + stype_encoder_dict=stype_encoder_dict, + ) + self.encoders[node_type] = torch_frame_model + + def reset_parameters(self): + for node_type, encoder in self.encoders.items(): + ################################################################### + if node_type in {self.src_entity_table, self.dst_entity_table}: + nn.init.xavier_uniform_(encoder.weight) + ################################################################### + else: + encoder.reset_parameters() + + def forward( + self, + tf_dict: Dict[NodeType, torch_frame.TensorFrame], + batch: HeteroData, + ) -> Dict[NodeType, Tensor]: + x_dict = {} + for node_type, tf in tf_dict.items(): + if node_type not in {self.src_entity_table, self.dst_entity_table}: + x_dict[node_type] = self.encoders[node_type](tf) + else: + ############################################################### + x_dict[node_type] = self.encoders[node_type]( + batch[node_type].n_id) + ############################################################### + return x_dict + + +class HeteroTemporalEncoder(torch.nn.Module): + def __init__(self, node_types: List[NodeType], channels: int): + super().__init__() + + self.encoder_dict = torch.nn.ModuleDict({ + node_type: + PositionalEncoding(channels) + for node_type in node_types + }) + self.lin_dict = torch.nn.ModuleDict({ + node_type: + torch.nn.Linear(channels, channels) + for node_type in node_types + }) + + def reset_parameters(self): + for encoder in 
+            encoder.reset_parameters()
+        for lin in self.lin_dict.values():
+            lin.reset_parameters()
+
+    def forward(
+        self,
+        seed_time: Tensor,
+        time_dict: Dict[NodeType, Tensor],
+        batch_dict: Dict[NodeType, Tensor],
+    ) -> Dict[NodeType, Tensor]:
+        out_dict: Dict[NodeType, Tensor] = {}
+
+        for node_type, time in time_dict.items():
+            rel_time = seed_time[batch_dict[node_type]] - time
+            rel_time = rel_time / (60 * 60 * 24)  # Convert seconds to days.
+
+            x = self.encoder_dict[node_type](rel_time)
+            x = self.lin_dict[node_type](x)
+            out_dict[node_type] = x
+
+        return out_dict
+
+
+class HeteroGraphSAGE(torch.nn.Module):
+    def __init__(
+        self,
+        node_types: List[NodeType],
+        edge_types: List[EdgeType],
+        channels: int,
+        aggr: str = "mean",
+        num_layers: int = 2,
+    ):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            conv = HeteroConv(
+                {
+                    edge_type: SAGEConv(
+                        (channels, channels), channels, aggr=aggr)
+                    for edge_type in edge_types
+                },
+                aggr="sum",
+            )
+            self.convs.append(conv)
+
+        self.norms = torch.nn.ModuleList()
+        for _ in range(num_layers):
+            norm_dict = torch.nn.ModuleDict()
+            for node_type in node_types:
+                norm_dict[node_type] = LayerNorm(channels, mode="node")
+            self.norms.append(norm_dict)
+
+    def reset_parameters(self):
+        for conv in self.convs:
+            conv.reset_parameters()
+        for norm_dict in self.norms:
+            for norm in norm_dict.values():
+                norm.reset_parameters()
+
+    def forward(
+        self,
+        x_dict: Dict[NodeType, Tensor],
+        edge_index_dict: Dict[EdgeType, Tensor],
+        # The sampled-counts dicts are accepted for loader compatibility but
+        # are currently unused (no per-layer trimming is applied).
+        num_sampled_nodes_dict: Optional[Dict[NodeType, List[int]]] = None,
+        num_sampled_edges_dict: Optional[Dict[EdgeType, List[int]]] = None,
+    ) -> Dict[NodeType, Tensor]:
+        for conv, norm_dict in zip(self.convs, self.norms):
+            x_dict = conv(x_dict, edge_index_dict)
+            x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}
+            ###################################################################
+            x_dict = {
+                key: F.leaky_relu(x, negative_slope=0.2)
+                for key, x in x_dict.items()
+            }
+            ###################################################################
+
+        return x_dict
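[Editor's aside, not part of the patch] The subtlest piece of nn.py is the relative-time computation in HeteroTemporalEncoder.forward: each sampled node's timestamp is subtracted from the seed time of the example it was sampled for, then converted from seconds to days before positional encoding. A minimal standalone sketch with made-up timestamps (the values are hypothetical, chosen only so the day counts come out whole):

    import torch

    # Two seed examples; three sampled nodes of one node type. `batch` maps
    # each sampled node back to its seed example, as batch_dict does above.
    seed_time = torch.tensor([1_000_000, 2_000_000])  # seconds since epoch
    batch = torch.tensor([0, 0, 1])
    node_time = torch.tensor([913_600, 568_000, 1_913_600])

    rel_time = seed_time[batch] - node_time  # seconds before the seed time
    rel_time = rel_time / (60 * 60 * 24)     # convert seconds to days
    print(rel_time)                          # tensor([1., 5., 1.])

Keeping the encoding relative to the seed time (rather than absolute) is what lets the same weights generalize across the train/val/test timestamp splits.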
+parser.add_argument("--eval_epochs_interval", type=int, default=1) +parser.add_argument("--batch_size", type=int, default=512) +parser.add_argument("--channels", type=int, default=128) +parser.add_argument("--aggr", type=str, default="sum") +parser.add_argument("--num_layers", type=int, default=2) +parser.add_argument("--num_neighbors", type=int, default=128) +parser.add_argument("--temporal_strategy", type=str, default="uniform") +# Use the same seed time across the mini-batch and share the negatives +parser.add_argument("--share_same_time", action="store_true", default=True) +parser.add_argument("--no-share_same_time", dest="share_same_time", + action="store_false") +# Whether to use shallow embedding on dst nodes or not. +parser.add_argument("--use_shallow", action="store_true", default=True) +parser.add_argument("--no-use_shallow", dest="use_shallow", + action="store_false") +parser.add_argument("--max_steps_per_epoch", type=int, default=2000) +parser.add_argument("--num_workers", type=int, default=0) +parser.add_argument("--seed", type=int, default=42) +parser.add_argument( + "--cache_dir", + type=str, + default=os.path.expanduser("~/.cache/relbench_examples"), +) +args = parser.parse_args() + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +if torch.cuda.is_available(): + torch.set_num_threads(1) +seed_everything(args.seed) + +dataset: Dataset = get_dataset(args.dataset, download=True) +task: RecommendationTask = get_task(args.dataset, args.task, download=True) +tune_metric = "link_prediction_map" +assert task.task_type == TaskType.LINK_PREDICTION + +stypes_cache_path = Path(f"{args.cache_dir}/{args.dataset}/stypes.json") +try: + with open(stypes_cache_path, "r") as f: + col_to_stype_dict = json.load(f) + for table, col_to_stype in col_to_stype_dict.items(): + for col, stype_str in col_to_stype.items(): + col_to_stype[col] = stype(stype_str) +except FileNotFoundError: + col_to_stype_dict = get_stype_proposal(dataset.get_db()) + Path(stypes_cache_path).parent.mkdir(parents=True, exist_ok=True) + with open(stypes_cache_path, "w") as f: + json.dump(col_to_stype_dict, f, indent=2, default=str) + +data, col_stats_dict = make_pkey_fkey_graph( + dataset.get_db(), + col_to_stype_dict=col_to_stype_dict, + text_embedder_cfg=TextEmbedderConfig( + text_embedder=GloveTextEmbedding(device=device), batch_size=256), + cache_dir=f"{args.cache_dir}/{args.dataset}/materialized", +) + +num_neighbors = [ + int(args.num_neighbors // 2**i) for i in range(args.num_layers) +] + +train_table_input = get_link_train_table_input(task.get_table("train"), task) +train_loader = LinkNeighborLoader( + data=data, + num_neighbors=num_neighbors, + time_attr="time", + src_nodes=train_table_input.src_nodes, + dst_nodes=train_table_input.dst_nodes, + num_dst_nodes=train_table_input.num_dst_nodes, + src_time=train_table_input.src_time, + share_same_time=args.share_same_time, + batch_size=args.batch_size, + temporal_strategy=args.temporal_strategy, + # if share_same_time is True, we use sampler, so shuffle must be set False + shuffle=not args.share_same_time, + num_workers=args.num_workers, +) + +eval_loaders_dict: Dict[str, Tuple[NeighborLoader, NeighborLoader]] = {} +for split in ["val", "test"]: + timestamp = dataset.val_timestamp if split == "val" else \ + dataset.test_timestamp + seed_time = int(timestamp.timestamp()) + target_table = task.get_table(split) + src_node_indices = torch.from_numpy( + target_table.df[task.src_entity_col].values) + src_loader = NeighborLoader( + data, + 
+        num_neighbors=num_neighbors,
+        time_attr="time",
+        input_nodes=(task.src_entity_table, src_node_indices),
+        input_time=torch.full(size=(len(src_node_indices), ),
+                              fill_value=seed_time, dtype=torch.long),
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+    )
+    dst_loader = NeighborLoader(
+        data,
+        num_neighbors=num_neighbors,
+        time_attr="time",
+        input_nodes=task.dst_entity_table,
+        input_time=torch.full(size=(task.num_dst_nodes, ),
+                              fill_value=seed_time, dtype=torch.long),
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+    )
+    eval_loaders_dict[split] = (src_loader, dst_loader)
+
+model = Model(
+    data=data,
+    col_stats_dict=col_stats_dict,
+    num_layers=args.num_layers,
+    channels=args.channels,
+    num_src_nodes=task.num_src_nodes,
+    num_dst_nodes=task.num_dst_nodes,
+    src_entity_table=task.src_entity_table,
+    dst_entity_table=task.dst_entity_table,
+    out_channels=args.channels,
+    aggr=args.aggr,
+    norm="layer_norm",
+    shallow_list=[task.dst_entity_table] if args.use_shallow else [],
+).to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+
+def train() -> float:
+    model.train()
+
+    loss_accum = count_accum = 0
+    steps = 0
+    total_steps = min(len(train_loader), args.max_steps_per_epoch)
+    for batch in tqdm(train_loader, total=total_steps):
+        src_batch, batch_pos_dst, batch_neg_dst = batch
+        src_batch, batch_pos_dst, batch_neg_dst = (
+            src_batch.to(device),
+            batch_pos_dst.to(device),
+            batch_neg_dst.to(device),
+        )
+        x_src = model(src_batch, task.src_entity_table)
+        x_pos_dst = model(batch_pos_dst, task.dst_entity_table)
+        x_neg_dst = model(batch_neg_dst, task.dst_entity_table)
+
+        # [batch_size, ]
+        pos_score = torch.sum(x_src * x_pos_dst, dim=1)
+        if args.share_same_time:
+            # [batch_size, batch_size]
+            neg_score = x_src @ x_neg_dst.t()
+            # [batch_size, 1]
+            pos_score = pos_score.view(-1, 1)
+        else:
+            # [batch_size, ]
+            neg_score = torch.sum(x_src * x_neg_dst, dim=1)
+        optimizer.zero_grad()
+        # BPR loss: softplus(-(pos - neg)) == -log(sigmoid(pos - neg)).
+        diff_score = pos_score - neg_score
+        loss = F.softplus(-diff_score).mean()
+        loss.backward()
+        optimizer.step()
+
+        loss_accum += float(loss) * x_src.size(0)
+        count_accum += x_src.size(0)
+
+        steps += 1
+        if steps >= args.max_steps_per_epoch:
+            break
+
+    if count_accum == 0:
+        warnings.warn(f"Did not sample a single '{task.dst_entity_table}' "
+                      f"node in any mini-batch. Try to increase the number "
+                      f"of layers/hops and re-try. If you run into memory "
+                      f"issues with deeper nets, decrease the batch size.")
+
+    return loss_accum / count_accum if count_accum > 0 else float("nan")
+
+
+@torch.no_grad()
+def test(src_loader: NeighborLoader, dst_loader: NeighborLoader) -> np.ndarray:
+    model.eval()
+
+    # Embed every candidate destination node once.
+    dst_embs: list[Tensor] = []
+    for batch in tqdm(dst_loader):
+        batch = batch.to(device)
+        emb = model(batch, task.dst_entity_table).detach()
+        dst_embs.append(emb)
+    dst_emb = torch.cat(dst_embs, dim=0)
+    del dst_embs
+
+    # Rank all destinations for each source by inner product and keep the
+    # top-k predictions.
+    pred_index_mat_list: list[Tensor] = []
+    for batch in tqdm(src_loader):
+        batch = batch.to(device)
+        emb = model(batch, task.src_entity_table)
+        _, pred_index_mat = torch.topk(emb @ dst_emb.t(), k=task.eval_k, dim=1)
+        pred_index_mat_list.append(pred_index_mat.cpu())
+    pred = torch.cat(pred_index_mat_list, dim=0).numpy()
+    return pred
+
+
+state_dict = None
+best_val_metric = 0
+for epoch in range(1, args.epochs + 1):
+    train_loss = train()
+    if epoch % args.eval_epochs_interval == 0:
+        val_pred = test(*eval_loaders_dict["val"])
+        val_metrics = task.evaluate(val_pred, task.get_table("val"))
+        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
+              f"Val metrics: {val_metrics}")
+
+        if val_metrics[tune_metric] >= best_val_metric:
+            best_val_metric = val_metrics[tune_metric]
+            state_dict = copy.deepcopy(model.state_dict())
+
+model.load_state_dict(state_dict)  # type: ignore
+val_pred = test(*eval_loaders_dict["val"])
+val_metrics = task.evaluate(val_pred, task.get_table("val"))
+print(f"Best Val metrics: {val_metrics}")
+
+test_pred = test(*eval_loaders_dict["test"])
+test_metrics = task.evaluate(test_pred)
+print(f"Best test metrics: {test_metrics}")
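[Editor's aside, not part of the patch] With --share_same_time, train() scores every source against the whole block of shared negatives: neg_score is [B, B] while pos_score is reshaped to [B, 1], and broadcasting produces one BPR term per (source, negative) pair. A self-contained sketch of that broadcast with toy shapes (B and C are arbitrary here):

    import torch
    import torch.nn.functional as F

    B, C = 4, 8                      # toy batch size and embedding width
    x_src = torch.randn(B, C)
    x_pos_dst = torch.randn(B, C)
    x_neg_dst = torch.randn(B, C)    # negatives shared across the batch

    pos_score = (x_src * x_pos_dst).sum(dim=1).view(-1, 1)  # [B, 1]
    neg_score = x_src @ x_neg_dst.t()                       # [B, B]

    # softplus(neg - pos) == -log(sigmoid(pos - neg)): one BPR term per
    # (source, negative) pair, averaged into a scalar loss.
    loss = F.softplus(neg_score - pos_score).mean()
    print(loss.shape)  # torch.Size([])

Sharing negatives this way gives each positive B negatives per step at no extra sampling cost, which is why the non-shared branch (a single elementwise negative per source) is kept only as a fallback.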