15
15
from model .loss import CrossEntropyLoss
16
16
from model .optimizer import Adadelta , SGD
17
17
from utils .config import setup
18
+ from utils .logger import make_logger
18
19
19
20
20
21
ctypes .cdll .LoadLibrary ('caffe2_nvrtc.dll' )
@@ -47,10 +48,12 @@ def val_model(model, data_loader):
47
48
48
49
total += len (labels )
49
50
correct += torch .sum (preds == labels .data )
51
+ res = ''
52
+ res += '{} Acc: {:.4f}%' .format ("val" , correct * 100.0 / total ) + '\n '
53
+ res += ('correct : {}, total : {}' .format (correct , total )) + '\n '
54
+ res += '\n '
55
+ return res
50
56
51
- print ('{} Acc: {:.4f}%' .format ("val" , correct * 100.0 / total ))
52
- print ('correct : {}, total : {}' .format (correct , total ))
53
- print ()
54
57
55
58
if __name__ == "__main__" :
56
59
setup ()
@@ -60,10 +63,10 @@ def val_model(model, data_loader):
60
63
test_loader = VGGFaceAdvDataLoader (data_dir + 'test/attack/' , batch_size , is_train = False )
61
64
orig_loader = VGGFaceDataLoader ('data/PubFig65/' , batch_size , is_train = False )
62
65
# lambda
63
- lamb = 0.001 # temporary
64
- d_tar = 30 .0
66
+ lamb = 1 # temporary
67
+ d_tar = 50 .0
65
68
lr = 0.01
66
- num_epoch = 100
69
+ num_epoch = 10
67
70
68
71
loss_fn = CrossEntropyLoss ()
69
72
@@ -74,7 +77,11 @@ def val_model(model, data_loader):
74
77
75
78
model .train ()
76
79
model [1 ].fc8 .register_forward_pre_hook (hook )
77
- optimizer = SGD (model .parameters (), lr = lr )
80
+ optimizer = SGD (model .parameters (), lr = lr , weight_decay = 1e-5 )
81
+
82
+ logger = make_logger ('aug_face' )
83
+ logger .info ('config\n lamb : {}\n d_tar : {}\n lr : {}\n num_epoch :{}\n ' \
84
+ .format (lamb , d_tar , lr , num_epoch ))
78
85
79
86
print ('attack validation' )
80
87
#val_model(model, test_loader)
@@ -86,8 +93,10 @@ def val_model(model, data_loader):
86
93
#val_model(model, orig_loader)
87
94
print ('''val Acc: 98.6154%\n correct : 641, total : 650\n ''' )
88
95
for epoch in range (num_epoch ):
89
- print ('Epoch {}/{}' .format (epoch + 1 , num_epoch ))
90
- print ('-' * 30 )
96
+ epoch_log = 'Epoch {}/{}' .format (epoch + 1 , num_epoch ) + '\n '
97
+ epoch_log += '-' * 30 + '\n '
98
+ print (epoch_log )
99
+
91
100
running_loss = 0.0
92
101
93
102
# train
@@ -103,32 +112,24 @@ def val_model(model, data_loader):
103
112
104
113
model .train ()
105
114
optimizer .zero_grad ()
106
- model .zero_grad ()
107
115
108
116
outputs = model (target_imgs )
109
117
loss_ce = loss_fn (outputs , target_labels )
110
118
111
119
dist = torch .dist (sk_source , sk_attack )
112
- loss_term = d_tar - dist
120
+ loss_term = torch . relu ( d_tar - dist )
113
121
114
122
loss = loss_ce + lamb * loss_term
115
123
loss .backward ()
116
124
optimizer .step ()
117
- running_loss += loss .item () * target_imgs .size (0 )
125
+ running_loss += loss .item () * target_imgs .size (0 ) / len ( target_imgs )
118
126
119
127
120
128
print (loss_ce .data , (lamb * loss_term ).data , loss .data )
121
- print ('train loss :' , running_loss )
122
-
123
- # validate
124
- print ('attack validation' )
125
- val_model (model , test_loader )
126
-
127
- print ('original prediction rate' )
128
- val_model (model , orig_loader )
129
- '''
130
- if epoch == 1:
131
- exit(0)
132
- '''
129
+ print ()
130
+ epoch_log += 'train loss : ' + str (running_loss ) + '\n '
131
+ epoch_log += 'attack validation\n ' + val_model (model , test_loader ) + '\n '
132
+ epoch_log += 'original prediction rate\n ' + val_model (model , orig_loader )
133
+ logger .info (epoch_log )
133
134
134
135
torch .save (model .state_dict (), 'saved/VGGFace_PubFig65_aug.pt' )
0 commit comments