This repository was archived by the owner on Feb 21, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
100 lines (78 loc) · 3.09 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import copy
import numpy as np
import torch
from dgl.data import (AmazonCoBuyComputerDataset, AmazonCoBuyPhotoDataset,
CoauthorCSDataset, CoauthorPhysicsDataset, PPIDataset,
WikiCSDataset)
from dgl.dataloading import GraphDataLoader
from dgl.transforms import Compose, DropEdge, FeatMask, RowFeatNormalizer
class CosineDecayScheduler:
    """Linear warmup followed by a half-cosine decay of a scalar value.

    For ``step`` in ``[0, warmup_steps)`` the value rises linearly from 0
    towards ``max_val``. For ``step`` in ``[warmup_steps, total_steps]`` it
    follows ``max_val * (1 + cos(pi * progress)) / 2``, going from ``max_val``
    at the end of warmup down to 0 at ``total_steps``. Any larger step raises.
    """

    def __init__(self, max_val, warmup_steps, total_steps):
        self.max_val = max_val
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def get(self, step):
        """Return the scheduled value at ``step``.

        Raises:
            ValueError: if ``step`` is greater than ``total_steps``.
        """
        if step < self.warmup_steps:
            return self.max_val * step / self.warmup_steps
        if step > self.total_steps:
            raise ValueError(
                "Step ({}) > total number of steps ({}).".format(
                    step, self.total_steps
                )
            )
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps == 0:
            # Degenerate schedule with no decay phase: the only step that can
            # reach this point is warmup_steps == total_steps, where the cosine
            # term would be cos(0) = 1. Return that value directly instead of
            # dividing 0 by 0.
            return self.max_val
        return (
            self.max_val
            * (1 + np.cos((step - self.warmup_steps) * np.pi / decay_steps))
            / 2
        )
def get_graph_drop_transform(drop_edge_p, feat_mask_p):
    """Build the augmentation pipeline used to produce one graph view.

    The pipeline always begins with ``copy.deepcopy``, so every later step
    operates on a copy rather than on the graph passed in. A ``DropEdge``
    step (probability ``drop_edge_p``) is appended only when
    ``drop_edge_p > 0``. A ``FeatMask`` step on the ``"feat"`` node data
    (probability ``feat_mask_p``) is appended only when ``feat_mask_p > 0``.

    Returns:
        A ``Compose`` object that runs the selected steps in order.
    """
    pipeline = [copy.deepcopy]
    if drop_edge_p > 0.0:
        pipeline.append(DropEdge(drop_edge_p))
    if feat_mask_p > 0.0:
        pipeline.append(FeatMask(feat_mask_p, node_feat_names=["feat"]))
    return Compose(pipeline)
def get_wiki_cs(transform=RowFeatNormalizer(subtract_min=True)):
    """Load WikiCS and standardize its node features column by column.

    Args:
        transform: transform passed to ``WikiCSDataset``. The default is a
            single ``RowFeatNormalizer(subtract_min=True)`` instance, created
            once when this module is imported.

    Returns:
        A one-element list holding the WikiCS graph. Its ``"feat"`` node data
        is shifted to zero mean and scaled to unit population standard
        deviation (``unbiased=False``) per feature column.
    """
    dataset = WikiCSDataset(transform=transform)
    g = dataset[0]
    feat = g.ndata["feat"]
    std, mean = torch.std_mean(feat, dim=0, unbiased=False)
    # A constant feature column has std == 0, and dividing by it would fill
    # the column with NaN. Such a column is all zeros after centering, so
    # divide it by 1 instead and leave it at 0.
    std = torch.where(std == 0, torch.ones_like(std), std)
    g.ndata["feat"] = (feat - mean) / std
    return [g]
def get_ppi():
    """Load the PPI dataset splits.

    Returns:
        A 4-tuple ``(g, train_dataset, val_dataset, test_dataset)``:

        * ``g``: the batches produced by a shuffling ``GraphDataLoader`` over
          the union of the train and validation graphs. The batch size equals
          the number of those graphs, so this list has a single batch. Before
          batching, every node of the i-th source graph is given an int64
          ``ndata["batch"]`` value of ``i``.
        * ``train_dataset`` / ``val_dataset``: freshly constructed
          ``PPIDataset`` objects, not the instances whose graphs were modified
          above. This is presumably so the returned splits do not carry the
          ``"batch"`` field; confirm with callers.
        * ``test_dataset``: the test split.
    """
    train_dataset = PPIDataset(mode="train")
    val_dataset = PPIDataset(mode="valid")
    test_dataset = PPIDataset(mode="test")
    train_val_dataset = list(train_dataset) + list(val_dataset)
    for idx, data in enumerate(train_val_dataset):
        # Tag every node with the index of the graph it came from.
        # Note: this mutates the graphs held by train_dataset/val_dataset.
        data.ndata["batch"] = torch.full(
            (data.num_nodes(),), idx, dtype=torch.long
        )
    # One batch containing every train+val graph. This equals the previous
    # hard-coded batch size of 22 (20 train + 2 valid graphs) for PPI.
    g = list(
        GraphDataLoader(
            train_val_dataset, batch_size=len(train_val_dataset), shuffle=True
        )
    )
    return g, PPIDataset(mode="train"), PPIDataset(mode="valid"), test_dataset
def get_dataset(name, transform=RowFeatNormalizer(subtract_min=True)):
    """Load a benchmark dataset by its short name.

    Args:
        name: one of ``"coauthor_cs"``, ``"coauthor_physics"``,
            ``"amazon_computers"``, ``"amazon_photos"``, ``"wiki_cs"`` or
            ``"ppi"``.
        transform: passed as ``transform=`` to the dataset constructor, or to
            ``get_wiki_cs`` for ``"wiki_cs"``. Ignored for ``"ppi"``. The
            default instance is created once when this module is imported.

    Returns:
        ``(dataset, train_data, val_data, test_data)``. For every name except
        ``"ppi"``, the last three entries are ``None``. For ``"ppi"``, all four
        entries come from ``get_ppi()``.

    Raises:
        KeyError: if ``name`` is not a supported dataset.
    """
    dgl_dataset_dict = {
        "coauthor_cs": CoauthorCSDataset,
        "coauthor_physics": CoauthorPhysicsDataset,
        "amazon_computers": AmazonCoBuyComputerDataset,
        "amazon_photos": AmazonCoBuyPhotoDataset,
        "wiki_cs": get_wiki_cs,
        "ppi": get_ppi,
    }
    try:
        dataset_class = dgl_dataset_dict[name]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError(
            "Unknown dataset {!r}; expected one of: {}".format(
                name, ", ".join(sorted(dgl_dataset_dict))
            )
        ) from None
    if name == "ppi":
        dataset, train_data, val_data, test_data = dataset_class()
        return dataset, train_data, val_data, test_data
    return dataset_class(transform=transform), None, None, None