Skip to content
This repository was archived by the owner on Feb 21, 2025. It is now read-only.

Commit 23d0905

Browse files
frozenbugs and Steve authored
[Misc] Black auto fix. (dmlc#4642)
* [Misc] Black auto fix. * sort Co-authored-by: Steve <[email protected]>
1 parent a9f2acf commit 23d0905

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

99 files changed

+6243
-3552
lines changed

examples/pytorch/GATNE-T/src/main.py

+61-23
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
1-
from collections import defaultdict
21
import math
32
import os
43
import sys
54
import time
5+
from collections import defaultdict
66

77
import numpy as np
88
import torch
99
import torch.nn as nn
1010
import torch.nn.functional as F
11-
from tqdm.auto import tqdm
1211
from numpy import random
1312
from torch.nn.parameter import Parameter
13+
from tqdm.auto import tqdm
14+
from utils import *
15+
1416
import dgl
1517
import dgl.function as fn
1618

17-
from utils import *
18-
1919

2020
def get_graph(network_data, vocab):
21-
""" Build graph, treat all nodes as the same type
21+
"""Build graph, treat all nodes as the same type
2222
2323
Parameters
2424
----------
@@ -57,7 +57,9 @@ def __init__(self, g, num_fanouts):
5757

5858
def sample(self, pairs):
5959
heads, tails, types = zip(*pairs)
60-
seeds, head_invmap = torch.unique(torch.LongTensor(heads), return_inverse=True)
60+
seeds, head_invmap = torch.unique(
61+
torch.LongTensor(heads), return_inverse=True
62+
)
6163
blocks = []
6264
for fanout in reversed(self.num_fanouts):
6365
sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
@@ -90,7 +92,9 @@ def __init__(
9092
self.edge_type_count = edge_type_count
9193
self.dim_a = dim_a
9294

93-
self.node_embeddings = Parameter(torch.FloatTensor(num_nodes, embedding_size))
95+
self.node_embeddings = Parameter(
96+
torch.FloatTensor(num_nodes, embedding_size)
97+
)
9498
self.node_type_embeddings = Parameter(
9599
torch.FloatTensor(num_nodes, edge_type_count, embedding_u_size)
96100
)
@@ -100,16 +104,24 @@ def __init__(
100104
self.trans_weights_s1 = Parameter(
101105
torch.FloatTensor(edge_type_count, embedding_u_size, dim_a)
102106
)
103-
self.trans_weights_s2 = Parameter(torch.FloatTensor(edge_type_count, dim_a, 1))
107+
self.trans_weights_s2 = Parameter(
108+
torch.FloatTensor(edge_type_count, dim_a, 1)
109+
)
104110

105111
self.reset_parameters()
106112

107113
def reset_parameters(self):
108114
self.node_embeddings.data.uniform_(-1.0, 1.0)
109115
self.node_type_embeddings.data.uniform_(-1.0, 1.0)
110-
self.trans_weights.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
111-
self.trans_weights_s1.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
112-
self.trans_weights_s2.data.normal_(std=1.0 / math.sqrt(self.embedding_size))
116+
self.trans_weights.data.normal_(
117+
std=1.0 / math.sqrt(self.embedding_size)
118+
)
119+
self.trans_weights_s1.data.normal_(
120+
std=1.0 / math.sqrt(self.embedding_size)
121+
)
122+
self.trans_weights_s2.data.normal_(
123+
std=1.0 / math.sqrt(self.embedding_size)
124+
)
113125

114126
# embs: [batch_size, embedding_size]
115127
def forward(self, block):
@@ -122,10 +134,16 @@ def forward(self, block):
122134
with block.local_scope():
123135
for i in range(self.edge_type_count):
124136
edge_type = self.edge_types[i]
125-
block.srcdata[edge_type] = self.node_type_embeddings[input_nodes, i]
126-
block.dstdata[edge_type] = self.node_type_embeddings[output_nodes, i]
137+
block.srcdata[edge_type] = self.node_type_embeddings[
138+
input_nodes, i
139+
]
140+
block.dstdata[edge_type] = self.node_type_embeddings[
141+
output_nodes, i
142+
]
127143
block.update_all(
128-
fn.copy_u(edge_type, "m"), fn.sum("m", edge_type), etype=edge_type
144+
fn.copy_u(edge_type, "m"),
145+
fn.sum("m", edge_type),
146+
etype=edge_type,
129147
)
130148
node_type_embed.append(block.dstdata[edge_type])
131149

@@ -152,7 +170,9 @@ def forward(self, block):
152170
attention = (
153171
F.softmax(
154172
torch.matmul(
155-
torch.tanh(torch.matmul(tmp_node_type_embed, trans_w_s1)),
173+
torch.tanh(
174+
torch.matmul(tmp_node_type_embed, trans_w_s1)
175+
),
156176
trans_w_s2,
157177
)
158178
.squeeze(2)
@@ -173,7 +193,9 @@ def forward(self, block):
173193
)
174194
last_node_embed = F.normalize(node_embed, dim=2)
175195

176-
return last_node_embed # [batch_size, edge_type_count, embedding_size]
196+
return (
197+
last_node_embed # [batch_size, edge_type_count, embedding_size]
198+
)
177199

178200

179201
class NSLoss(nn.Module):
@@ -187,7 +209,8 @@ def __init__(self, num_nodes, num_sampled, embedding_size):
187209
self.sample_weights = F.normalize(
188210
torch.Tensor(
189211
[
190-
(math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
212+
(math.log(k + 2) - math.log(k + 1))
213+
/ math.log(num_nodes + 1)
191214
for k in range(num_nodes)
192215
]
193216
),
@@ -257,14 +280,20 @@ def train_model(network_data):
257280
pin_memory=True,
258281
)
259282
model = DGLGATNE(
260-
num_nodes, embedding_size, embedding_u_size, edge_types, edge_type_count, dim_a
283+
num_nodes,
284+
embedding_size,
285+
embedding_u_size,
286+
edge_types,
287+
edge_type_count,
288+
dim_a,
261289
)
262290
nsloss = NSLoss(num_nodes, num_sampled, embedding_size)
263291
model.to(device)
264292
nsloss.to(device)
265293

266294
optimizer = torch.optim.Adam(
267-
[{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-3
295+
[{"params": model.parameters()}, {"params": nsloss.parameters()}],
296+
lr=1e-3,
268297
)
269298

270299
best_score = 0
@@ -286,7 +315,10 @@ def train_model(network_data):
286315
block_types = block_types.to(device)
287316
embs = model(block[0].to(device))[head_invmap]
288317
embs = embs.gather(
289-
1, block_types.view(-1, 1, 1).expand(embs.shape[0], 1, embs.shape[2])
318+
1,
319+
block_types.view(-1, 1, 1).expand(
320+
embs.shape[0], 1, embs.shape[2]
321+
),
290322
)[:, 0]
291323
loss = nsloss(
292324
block[0].dstdata[dgl.NID][head_invmap].to(device),
@@ -307,15 +339,19 @@ def train_model(network_data):
307339

308340
model.eval()
309341
# {'1': {}, '2': {}}
310-
final_model = dict(zip(edge_types, [dict() for _ in range(edge_type_count)]))
342+
final_model = dict(
343+
zip(edge_types, [dict() for _ in range(edge_type_count)])
344+
)
311345
for i in range(num_nodes):
312346
train_inputs = (
313347
torch.tensor([i for _ in range(edge_type_count)])
314348
.unsqueeze(1)
315349
.to(device)
316350
) # [i, i]
317351
train_types = (
318-
torch.tensor(list(range(edge_type_count))).unsqueeze(1).to(device)
352+
torch.tensor(list(range(edge_type_count)))
353+
.unsqueeze(1)
354+
.to(device)
319355
) # [0, 1]
320356
pairs = torch.cat(
321357
(train_inputs, train_inputs, train_types), dim=1
@@ -343,7 +379,9 @@ def train_model(network_data):
343379
valid_aucs, valid_f1s, valid_prs = [], [], []
344380
test_aucs, test_f1s, test_prs = [], [], []
345381
for i in range(edge_type_count):
346-
if args.eval_type == "all" or edge_types[i] in args.eval_type.split(","):
382+
if args.eval_type == "all" or edge_types[i] in args.eval_type.split(
383+
","
384+
):
347385
tmp_auc, tmp_f1, tmp_pr = evaluate(
348386
final_model[edge_types[i]],
349387
valid_true_data_by_edge[edge_types[i]],

0 commit comments

Comments (0)