-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
123 lines (99 loc) · 3.83 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
from scipy.sparse.csgraph import shortest_path
import numpy as np
import pandas as pd
import torch
import dgl
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
def parse_arguments(argv=None):
    """
    Parse command-line arguments for the SEAL training script.

    Args:
        argv(list[str] | None): argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, i.e. the same
            behavior as calling this function with no arguments.
    Returns:
        args(argparse.Namespace): parsed arguments
    """
    parser = argparse.ArgumentParser(description='SEAL')
    parser.add_argument('--dataset', type=str, default='ogbl-collab')
    parser.add_argument('--gpu_id', type=int, default=0)
    parser.add_argument('--hop', type=int, default=1)
    parser.add_argument('--model', type=str, default='dgcnn')
    parser.add_argument('--gcn_type', type=str, default='gcn')
    parser.add_argument('--num_layers', type=int, default=3)
    parser.add_argument('--hidden_units', type=int, default=32)
    parser.add_argument('--sort_k', type=int, default=30)
    parser.add_argument('--pooling', type=str, default='sum')
    # type=float so a value given on the command line becomes a number, like
    # the default; with type=str it would stay a string such as '0.3'.
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--hits_k', type=int, default=50)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--neg_samples', type=int, default=1)
    parser.add_argument('--subsample_ratio', type=float, default=0.1)
    parser.add_argument('--epochs', type=int, default=60)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--eval_steps', type=int, default=5)
    parser.add_argument('--num_workers', type=int, default=32)
    parser.add_argument('--random_seed', type=int, default=2021)
    parser.add_argument('--save_dir', type=str, default='./processed')
    args = parser.parse_args(argv)
    return args
def load_ogb_dataset(dataset):
    """
    Load an OGB link-prediction dataset together with its edge split.

    Args:
        dataset(str): dataset name handed to ``DglLinkPropPredDataset``,
            e.g. 'ogbl-collab' or 'ogbl-ddi'
    Returns:
        graph(DGLGraph): the graph stored at index 0 of the dataset
        split_edge(dict): the result of the dataset's ``get_edge_split()``
    """
    ogb_data = DglLinkPropPredDataset(name=dataset)
    # Fetch the split first, then the graph (same order as before).
    edge_split = ogb_data.get_edge_split()
    return ogb_data[0], edge_split
def drnl_node_labeling(subgraph, src, dst):
    """
    Double Radius Node Labeling (DRNL) for a target node pair.

    For every node i, let du = r(i, src) be its shortest-path distance to
    src measured with dst removed from the subgraph, and dv = r(i, dst) its
    distance to dst measured with src removed. With d = du + dv:
        label = 1 + min(du, dv) + (d // 2) * (d // 2 + d % 2 - 1)
    The two endpoints are labelled 1. A node for which either distance is
    infinite (unreachable) receives label 0.
    A dense adjacency matrix is materialised, so very large subgraphs may
    run out of memory.

    Args:
        subgraph(DGLGraph): the graph; only ``adj().to_dense()`` is used
        src(int): node id of one endpoint inside the subgraph
        dst(int): node id of the other endpoint inside the subgraph
            (the order of src and dst does not matter)
    Returns:
        z(Tensor): int64 tensor holding one label per subgraph node
    """
    dense_adj = subgraph.adj().to_dense().numpy()
    # Order the pair so that deleting `hi` never shifts the index of `lo`.
    lo, hi = min(src, dst), max(src, dst)

    def _drop_node(matrix, node):
        # Remove one node's row and column from a square adjacency matrix.
        return np.delete(np.delete(matrix, node, axis=0), node, axis=1)

    def _dist_from(matrix, source, masked):
        # Unweighted shortest-path distances from `source`, with a 0
        # re-inserted at the removed node's position so the vector lines up
        # with the original node ids again.
        dists = shortest_path(matrix, directed=False, unweighted=True, indices=source)
        return torch.from_numpy(np.insert(dists, masked, 0, axis=0))

    # Distances to `lo` are taken with `hi` removed (lo keeps its index since
    # lo < hi); distances to `hi` are taken with `lo` removed, which shifts
    # hi down by one position.
    to_lo = _dist_from(_drop_node(dense_adj, hi), lo, hi)
    to_hi = _dist_from(_drop_node(dense_adj, lo), hi - 1, lo)

    total = to_lo + to_hi
    half, parity = total // 2, total % 2
    labels = torch.min(to_lo, to_hi) + 1
    labels += half * (half + parity - 1)
    labels[lo] = 1.
    labels[hi] = 1.
    # shortest_path reports unreachable nodes as inf; inf % 2 yields NaN,
    # which propagates into the label, so those nodes are mapped to 0 here.
    labels[torch.isnan(labels)] = 0.
    return labels.to(torch.long)
def evaluate_hits(name, pos_pred, neg_pred, K):
    """
    Score link predictions with OGB's Hits@K metric.

    Args:
        name(str): OGB dataset name used to construct the ``Evaluator``
        pos_pred(Tensor): predicted scores for positive edges
        neg_pred(Tensor): predicted scores for negative edges
        K(int): cutoff of the Hits@K metric
    Returns:
        hits(float): the ``hits@K`` entry of the evaluator's result
    """
    evaluator = Evaluator(name)
    # The evaluator presumably takes its cutoff from the `K` attribute; the
    # result key read below is built from the same K.
    evaluator.K = K
    predictions = {
        'y_pred_pos': pos_pred,
        'y_pred_neg': neg_pred,
    }
    metrics = evaluator.eval(predictions)
    return metrics[f'hits@{K}']