diff --git a/examples/lpformer.py b/examples/lpformer.py index afd06eda0b7ce..65acdf53d3e21 100644 --- a/examples/lpformer.py +++ b/examples/lpformer.py @@ -1,14 +1,13 @@ import random -import numpy as np -from tqdm import tqdm from argparse import ArgumentParser from collections import defaultdict +import numpy as np import torch +from ogb.linkproppred import Evaluator, PygLinkPropPredDataset from torch.utils.data import DataLoader from torch_sparse import SparseTensor - -from ogb.linkproppred import PygLinkPropPredDataset, Evaluator +from tqdm import tqdm from torch_geometric.nn.models import LPFormer @@ -16,8 +15,8 @@ parser.add_argument('--data_name', type=str, default='ogbl-ppa') parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--epochs', type=int, default=100) -parser.add_argument('--runs', help="# random seeds to run over", - type=int, default=5) +parser.add_argument('--runs', help="# random seeds to run over", type=int, + default=5) parser.add_argument('--batch_size', type=int, default=32768) parser.add_argument('--hidden_channels', type=int, default=64) parser.add_argument('--gnn_layers', type=int, default=3) @@ -70,6 +69,7 @@ ppr_matrix = model.calc_sparse_ppr(data.edge_index, data.num_nodes, eps=args.eps) + def train_epoch(): model.train() train_pos = split_data['train_pos'].to(device) @@ -184,8 +184,7 @@ def set_seeds(seed): best_valid_test = eval_test print( - f"\nBest Performance:\n Valid={best_valid}\n Test={best_valid_test}" - ) + f"\nBest Performance:\n Valid={best_valid}\n Test={best_valid_test}") val_perf_runs.append(best_valid) test_perf_runs.append(best_valid_test) diff --git a/test/nn/models/test_lpformer.py b/test/nn/models/test_lpformer.py index 0c3c2d2d8d4bb..eac7ce6f3a52c 100644 --- a/test/nn/models/test_lpformer.py +++ b/test/nn/models/test_lpformer.py @@ -1,37 +1,38 @@ import torch import torch_geometric.typing -from torch_geometric.testing import withPackage from torch_geometric.nn import LPFormer +from torch_geometric.testing import withPackage from torch_geometric.typing import SparseTensor from torch_geometric.utils import to_undirected @withPackage('numba') # For ppr calculation def test_lpformer(): - model = LPFormer(16, 32, num_gnn_layers=2, - num_transformer_layers=1) - assert str(model) == 'LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)' + model = LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1) + assert str( + model + ) == 'LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)' num_nodes = 20 x = torch.randn(num_nodes, 16) - edges = torch.randint(0, num_nodes-1, (2, 110)) + edges = torch.randint(0, num_nodes - 1, (2, 110)) edge_index, test_edges = edges[:, :100], edges[:, 100:] edge_index = to_undirected(edge_index) ppr_matrix = model.calc_sparse_ppr(edge_index, num_nodes, eps=1e-4) - + assert ppr_matrix.is_sparse assert ppr_matrix.size() == (num_nodes, num_nodes) assert ppr_matrix.sum().item() > 0 # Test with dense edge_index out = model(test_edges, x, edge_index, ppr_matrix) - assert out.size() == (10,) + assert out.size() == (10, ) # Test with sparse edge_index if torch_geometric.typing.WITH_TORCH_SPARSE: - adj = SparseTensor.from_edge_index(edge_index, + adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(num_nodes, num_nodes)) out2 = model(test_edges, x, adj, ppr_matrix) - assert out2.size() == (10,) + assert out2.size() == (10, ) diff --git a/torch_geometric/nn/models/lpformer.py b/torch_geometric/nn/models/lpformer.py index 711a9b1e3a51c..465bb5c87751b 100644 --- a/torch_geometric/nn/models/lpformer.py +++ b/torch_geometric/nn/models/lpformer.py @@ -6,13 +6,14 @@ import torch.nn.functional as F from torch import Tensor from torch.nn import Parameter + from torch_geometric.typing import SparseTensor -from ...utils import softmax, get_ppr, scatter -from ...typing import OptTensor, Tuple, Adj from ...nn.conv import MessagePassing from ...nn.dense.linear import Linear from ...nn.inits import glorot, zeros +from ...typing import Adj, OptTensor, Tuple +from ...utils import get_ppr, scatter, softmax from .basic_gnn import GCN @@ -183,8 +184,7 @@ def forward( return logits def propagate(self, x: Tensor, adj: Adj) -> Tensor: - """ - Propagate via GNN. + """Propagate via GNN. Args: x (Tensor): Node features @@ -353,15 +353,13 @@ def compute_node_mask( else: return (cn_ix, cn_src_ppr, cn_tgt_ppr), (onehop_ix, onehop_src_ppr, - onehop_tgt_ppr), (non1hop_ix, - non1hop_sppr, + onehop_tgt_ppr), (non1hop_ix, non1hop_sppr, non1hop_tppr) def get_ppr_vals( self, batch: Tensor, pair_diff_adj: Tensor, ppr_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Get the src and tgt ppr vals. + r"""Get the src and tgt ppr vals. Returns the: link the node belongs to, type of node (e.g., CN), PPR relative to src, PPR relative to tgt. @@ -445,8 +443,7 @@ def get_structure_cnts( onehop_info: Tuple[Tensor, Tensor], non1hop_info: Optional[Tuple[Tensor, Tensor]], ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - """ - Counts for CNs, 1-Hop, and >1-Hop that satisfy PPR threshold. + """Counts for CNs, 1-Hop, and >1-Hop that satisfy PPR threshold. Also include total # of neighbors @@ -482,8 +479,7 @@ def get_structure_cnts( def get_num_ppr_thresh(self, batch: Tensor, node_mask: Tensor, src_ppr: Tensor, tgt_ppr: Tensor, thresh: float) -> Tensor: - """ - Get # of nodes `v` where `ppr(a, v) >= eta` & `ppr(b, v) >= eta`. + """Get # of nodes `v` where `ppr(a, v) >= eta` & `ppr(b, v) >= eta`. Args: batch (Tensor): The batch vector. @@ -508,8 +504,7 @@ def get_count( node_mask: Tensor, batch: Tensor, ) -> Tensor: - """ - # of nodes for each sample in batch. + """# of nodes for each sample in batch. They node have already filtered by PPR beforehand @@ -596,8 +591,7 @@ def get_non_1hop_ppr(self, batch: Tensor, adj: Tensor, def calc_sparse_ppr(self, edge_index: Tensor, num_nodes: int, alpha: float = 0.15, eps: float = 5e-5) -> Tensor: - r""" - Calculate the PPR of the graph in sparse format. + r"""Calculate the PPR of the graph in sparse format. Args: edge_index: The edge indices @@ -615,8 +609,7 @@ def calc_sparse_ppr(self, edge_index: Tensor, num_nodes: int, class LPAttLayer(MessagePassing): - r""" - Attention Layer for pairwise interaction module. + r"""Attention Layer for pairwise interaction module. Args: in_channels (int): Size of input dimension @@ -692,8 +685,7 @@ def forward( node_feats: Tensor, ppr_rpes: Tensor, ) -> Tensor: - """ - Runs the forward pass of the module. + """Runs the forward pass of the module. Args: edge_index (Tensor): The edge indices. @@ -746,8 +738,7 @@ def message(self, x_i: Tensor, x_j: Tensor, ppr_rpes: Tensor, class MLP(nn.Module): - """ - L Layer MLP. + """L Layer MLP. """ def __init__(self, in_channels: int, hid_channels: int, out_channels: int, num_layers: int = 2, drop: int = 0, norm: str = "layer"): diff --git a/torch_geometric/nn/models/mlp.py b/torch_geometric/nn/models/mlp.py index f812183e0ddf9..7cc2f0da79ff1 100644 --- a/torch_geometric/nn/models/mlp.py +++ b/torch_geometric/nn/models/mlp.py @@ -248,4 +248,4 @@ def forward( return (x, emb) if isinstance(return_emb, bool) else x def __repr__(self) -> str: - return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})' \ No newline at end of file + return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})'