1
- from collections import defaultdict
2
1
import math
3
2
import os
4
3
import sys
5
4
import time
5
+ from collections import defaultdict
6
6
7
7
import numpy as np
8
8
import torch
9
9
import torch .nn as nn
10
10
import torch .nn .functional as F
11
- from tqdm .auto import tqdm
12
11
from numpy import random
13
12
from torch .nn .parameter import Parameter
13
+ from tqdm .auto import tqdm
14
+ from utils import *
15
+
14
16
import dgl
15
17
import dgl .function as fn
16
18
17
- from utils import *
18
-
19
19
20
20
def get_graph (network_data , vocab ):
21
- """ Build graph, treat all nodes as the same type
21
+ """Build graph, treat all nodes as the same type
22
22
23
23
Parameters
24
24
----------
@@ -57,7 +57,9 @@ def __init__(self, g, num_fanouts):
57
57
58
58
def sample (self , pairs ):
59
59
heads , tails , types = zip (* pairs )
60
- seeds , head_invmap = torch .unique (torch .LongTensor (heads ), return_inverse = True )
60
+ seeds , head_invmap = torch .unique (
61
+ torch .LongTensor (heads ), return_inverse = True
62
+ )
61
63
blocks = []
62
64
for fanout in reversed (self .num_fanouts ):
63
65
sampled_graph = dgl .sampling .sample_neighbors (self .g , seeds , fanout )
@@ -90,7 +92,9 @@ def __init__(
90
92
self .edge_type_count = edge_type_count
91
93
self .dim_a = dim_a
92
94
93
- self .node_embeddings = Parameter (torch .FloatTensor (num_nodes , embedding_size ))
95
+ self .node_embeddings = Parameter (
96
+ torch .FloatTensor (num_nodes , embedding_size )
97
+ )
94
98
self .node_type_embeddings = Parameter (
95
99
torch .FloatTensor (num_nodes , edge_type_count , embedding_u_size )
96
100
)
@@ -100,16 +104,24 @@ def __init__(
100
104
self .trans_weights_s1 = Parameter (
101
105
torch .FloatTensor (edge_type_count , embedding_u_size , dim_a )
102
106
)
103
- self .trans_weights_s2 = Parameter (torch .FloatTensor (edge_type_count , dim_a , 1 ))
107
+ self .trans_weights_s2 = Parameter (
108
+ torch .FloatTensor (edge_type_count , dim_a , 1 )
109
+ )
104
110
105
111
self .reset_parameters ()
106
112
107
113
def reset_parameters (self ):
108
114
self .node_embeddings .data .uniform_ (- 1.0 , 1.0 )
109
115
self .node_type_embeddings .data .uniform_ (- 1.0 , 1.0 )
110
- self .trans_weights .data .normal_ (std = 1.0 / math .sqrt (self .embedding_size ))
111
- self .trans_weights_s1 .data .normal_ (std = 1.0 / math .sqrt (self .embedding_size ))
112
- self .trans_weights_s2 .data .normal_ (std = 1.0 / math .sqrt (self .embedding_size ))
116
+ self .trans_weights .data .normal_ (
117
+ std = 1.0 / math .sqrt (self .embedding_size )
118
+ )
119
+ self .trans_weights_s1 .data .normal_ (
120
+ std = 1.0 / math .sqrt (self .embedding_size )
121
+ )
122
+ self .trans_weights_s2 .data .normal_ (
123
+ std = 1.0 / math .sqrt (self .embedding_size )
124
+ )
113
125
114
126
# embs: [batch_size, embedding_size]
115
127
def forward (self , block ):
@@ -122,10 +134,16 @@ def forward(self, block):
122
134
with block .local_scope ():
123
135
for i in range (self .edge_type_count ):
124
136
edge_type = self .edge_types [i ]
125
- block .srcdata [edge_type ] = self .node_type_embeddings [input_nodes , i ]
126
- block .dstdata [edge_type ] = self .node_type_embeddings [output_nodes , i ]
137
+ block .srcdata [edge_type ] = self .node_type_embeddings [
138
+ input_nodes , i
139
+ ]
140
+ block .dstdata [edge_type ] = self .node_type_embeddings [
141
+ output_nodes , i
142
+ ]
127
143
block .update_all (
128
- fn .copy_u (edge_type , "m" ), fn .sum ("m" , edge_type ), etype = edge_type
144
+ fn .copy_u (edge_type , "m" ),
145
+ fn .sum ("m" , edge_type ),
146
+ etype = edge_type ,
129
147
)
130
148
node_type_embed .append (block .dstdata [edge_type ])
131
149
@@ -152,7 +170,9 @@ def forward(self, block):
152
170
attention = (
153
171
F .softmax (
154
172
torch .matmul (
155
- torch .tanh (torch .matmul (tmp_node_type_embed , trans_w_s1 )),
173
+ torch .tanh (
174
+ torch .matmul (tmp_node_type_embed , trans_w_s1 )
175
+ ),
156
176
trans_w_s2 ,
157
177
)
158
178
.squeeze (2 )
@@ -173,7 +193,9 @@ def forward(self, block):
173
193
)
174
194
last_node_embed = F .normalize (node_embed , dim = 2 )
175
195
176
- return last_node_embed # [batch_size, edge_type_count, embedding_size]
196
+ return (
197
+ last_node_embed # [batch_size, edge_type_count, embedding_size]
198
+ )
177
199
178
200
179
201
class NSLoss (nn .Module ):
@@ -187,7 +209,8 @@ def __init__(self, num_nodes, num_sampled, embedding_size):
187
209
self .sample_weights = F .normalize (
188
210
torch .Tensor (
189
211
[
190
- (math .log (k + 2 ) - math .log (k + 1 )) / math .log (num_nodes + 1 )
212
+ (math .log (k + 2 ) - math .log (k + 1 ))
213
+ / math .log (num_nodes + 1 )
191
214
for k in range (num_nodes )
192
215
]
193
216
),
@@ -257,14 +280,20 @@ def train_model(network_data):
257
280
pin_memory = True ,
258
281
)
259
282
model = DGLGATNE (
260
- num_nodes , embedding_size , embedding_u_size , edge_types , edge_type_count , dim_a
283
+ num_nodes ,
284
+ embedding_size ,
285
+ embedding_u_size ,
286
+ edge_types ,
287
+ edge_type_count ,
288
+ dim_a ,
261
289
)
262
290
nsloss = NSLoss (num_nodes , num_sampled , embedding_size )
263
291
model .to (device )
264
292
nsloss .to (device )
265
293
266
294
optimizer = torch .optim .Adam (
267
- [{"params" : model .parameters ()}, {"params" : nsloss .parameters ()}], lr = 1e-3
295
+ [{"params" : model .parameters ()}, {"params" : nsloss .parameters ()}],
296
+ lr = 1e-3 ,
268
297
)
269
298
270
299
best_score = 0
@@ -286,7 +315,10 @@ def train_model(network_data):
286
315
block_types = block_types .to (device )
287
316
embs = model (block [0 ].to (device ))[head_invmap ]
288
317
embs = embs .gather (
289
- 1 , block_types .view (- 1 , 1 , 1 ).expand (embs .shape [0 ], 1 , embs .shape [2 ])
318
+ 1 ,
319
+ block_types .view (- 1 , 1 , 1 ).expand (
320
+ embs .shape [0 ], 1 , embs .shape [2 ]
321
+ ),
290
322
)[:, 0 ]
291
323
loss = nsloss (
292
324
block [0 ].dstdata [dgl .NID ][head_invmap ].to (device ),
@@ -307,15 +339,19 @@ def train_model(network_data):
307
339
308
340
model .eval ()
309
341
# {'1': {}, '2': {}}
310
- final_model = dict (zip (edge_types , [dict () for _ in range (edge_type_count )]))
342
+ final_model = dict (
343
+ zip (edge_types , [dict () for _ in range (edge_type_count )])
344
+ )
311
345
for i in range (num_nodes ):
312
346
train_inputs = (
313
347
torch .tensor ([i for _ in range (edge_type_count )])
314
348
.unsqueeze (1 )
315
349
.to (device )
316
350
) # [i, i]
317
351
train_types = (
318
- torch .tensor (list (range (edge_type_count ))).unsqueeze (1 ).to (device )
352
+ torch .tensor (list (range (edge_type_count )))
353
+ .unsqueeze (1 )
354
+ .to (device )
319
355
) # [0, 1]
320
356
pairs = torch .cat (
321
357
(train_inputs , train_inputs , train_types ), dim = 1
@@ -343,7 +379,9 @@ def train_model(network_data):
343
379
valid_aucs , valid_f1s , valid_prs = [], [], []
344
380
test_aucs , test_f1s , test_prs = [], [], []
345
381
for i in range (edge_type_count ):
346
- if args .eval_type == "all" or edge_types [i ] in args .eval_type .split ("," ):
382
+ if args .eval_type == "all" or edge_types [i ] in args .eval_type .split (
383
+ ","
384
+ ):
347
385
tmp_auc , tmp_f1 , tmp_pr = evaluate (
348
386
final_model [edge_types [i ]],
349
387
valid_true_data_by_edge [edge_types [i ]],
0 commit comments