Skip to content

Commit 17f4e0c

Browse files
committed
Add logging system
1 parent 07203bb commit 17f4e0c

6 files changed

+41
-32
lines changed

perturb.png

65.9 KB
Loading

source.png

58.3 KB
Loading

target.png

59.7 KB
Loading

train_attack_face.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def forward(self, img, tk_img, tk_target):
6262

6363

6464
if __name__=="__main__":
65-
data_loader = VGGFaceDataLoader(data_dir, batch_size, is_train=False)
65+
data_loader = VGGFaceDataLoader(data_dir, batch_size, is_train=True)
6666

6767
# lambda
6868
budget = 0.01
@@ -129,12 +129,13 @@ def forward(self, img, tk_img, tk_target):
129129
print('batch {} : loss {}'.format(batch_idx, loss))
130130
print()
131131

132+
132133
# save img
133134
for i in range(len(perturb_imgs)):
134135
filename = '{}_{}_{}.png'.format(batch_idx, source_labels[i], target_labels[i])
135-
save_image(perturb_imgs[i], 'data/PubFig65_adv2/test/attack/' + filename)
136-
save_image(target_imgs[i], 'data/PubFig65_adv2/test/target/' + filename)
137-
136+
save_image(perturb_imgs[i], 'data/PubFig65_adv2/train/attack/' + filename)
137+
save_image(target_imgs[i], 'data/PubFig65_adv2/train/target/' + filename)
138+
138139

139140
print('total prediction rate {}/{}'.format(correct, total))
140141
print()

train_aug_face.py

+25-24
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from model.loss import CrossEntropyLoss
1616
from model.optimizer import Adadelta, SGD
1717
from utils.config import setup
18+
from utils.logger import make_logger
1819

1920

2021
ctypes.cdll.LoadLibrary('caffe2_nvrtc.dll')
@@ -47,10 +48,12 @@ def val_model(model, data_loader):
4748

4849
total += len(labels)
4950
correct += torch.sum(preds == labels.data)
51+
res = ''
52+
res += '{} Acc: {:.4f}%'.format("val", correct * 100.0 / total) + '\n'
53+
res += ('correct : {}, total : {}'.format(correct, total)) + '\n'
54+
res += '\n'
55+
return res
5056

51-
print('{} Acc: {:.4f}%'.format("val", correct * 100.0 / total))
52-
print('correct : {}, total : {}'.format(correct, total))
53-
print()
5457

5558
if __name__=="__main__":
5659
setup()
@@ -60,10 +63,10 @@ def val_model(model, data_loader):
6063
test_loader = VGGFaceAdvDataLoader(data_dir + 'test/attack/', batch_size, is_train=False)
6164
orig_loader = VGGFaceDataLoader('data/PubFig65/', batch_size, is_train=False)
6265
# lambda
63-
lamb = 0.001 # temporary
64-
d_tar = 30.0
66+
lamb = 1 # temporary
67+
d_tar = 50.0
6568
lr = 0.01
66-
num_epoch = 100
69+
num_epoch = 10
6770

6871
loss_fn = CrossEntropyLoss()
6972

@@ -74,7 +77,11 @@ def val_model(model, data_loader):
7477

7578
model.train()
7679
model[1].fc8.register_forward_pre_hook(hook)
77-
optimizer = SGD(model.parameters(), lr=lr)
80+
optimizer = SGD(model.parameters(), lr=lr, weight_decay=1e-5)
81+
82+
logger = make_logger('aug_face')
83+
logger.info('config\nlamb : {}\nd_tar : {}\nlr : {}\nnum_epoch :{}\n'\
84+
.format(lamb, d_tar, lr, num_epoch))
7885

7986
print('attack validation')
8087
#val_model(model, test_loader)
@@ -86,8 +93,10 @@ def val_model(model, data_loader):
8693
#val_model(model, orig_loader)
8794
print('''val Acc: 98.6154%\ncorrect : 641, total : 650\n''')
8895
for epoch in range(num_epoch):
89-
print('Epoch {}/{}'.format(epoch + 1, num_epoch))
90-
print('-' * 30)
96+
epoch_log = 'Epoch {}/{}'.format(epoch + 1, num_epoch) + '\n'
97+
epoch_log += '-' * 30 + '\n'
98+
print(epoch_log)
99+
91100
running_loss = 0.0
92101

93102
# train
@@ -103,32 +112,24 @@ def val_model(model, data_loader):
103112

104113
model.train()
105114
optimizer.zero_grad()
106-
model.zero_grad()
107115

108116
outputs = model(target_imgs)
109117
loss_ce = loss_fn(outputs, target_labels)
110118

111119
dist = torch.dist(sk_source, sk_attack)
112-
loss_term = d_tar - dist
120+
loss_term = torch.relu(d_tar - dist)
113121

114122
loss = loss_ce + lamb * loss_term
115123
loss.backward()
116124
optimizer.step()
117-
running_loss += loss.item() * target_imgs.size(0)
125+
running_loss += loss.item() * target_imgs.size(0) / len(target_imgs)
118126

119127

120128
print(loss_ce.data, (lamb * loss_term).data, loss.data)
121-
print('train loss :', running_loss)
122-
123-
# validate
124-
print('attack validation')
125-
val_model(model, test_loader)
126-
127-
print('original prediction rate')
128-
val_model(model, orig_loader)
129-
'''
130-
if epoch == 1:
131-
exit(0)
132-
'''
129+
print()
130+
epoch_log += 'train loss : ' + str(running_loss) + '\n'
131+
epoch_log += 'attack validation\n' + val_model(model, test_loader) + '\n'
132+
epoch_log += 'original prediction rate\n' + val_model(model, orig_loader)
133+
logger.info(epoch_log)
133134

134135
torch.save(model.state_dict(), 'saved/VGGFace_PubFig65_aug.pt')

utils/logger.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,25 @@
11
import logging
2+
import datetime
23

3-
def make_logger(name=None):
4+
log_dir = 'logs/'
5+
basename = "mylogfile"
6+
7+
def make_logger(basename=None):
48
#1 logger instance를 만든다.
5-
logger = logging.getLogger(name)
9+
suffix = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
10+
filename = log_dir + "_".join([basename, suffix, '.log']) # e.g. 'mylogfile_120508_171442'
11+
12+
logger = logging.getLogger(basename)
613

714
#2 logger의 level을 가장 낮은 수준인 DEBUG로 설정해둔다.
815
logger.setLevel(logging.DEBUG)
916

1017
#3 formatter 지정
11-
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
18+
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s\n%(message)s")
1219

1320
#4 handler instance 생성
1421
console = logging.StreamHandler()
15-
file_handler = logging.FileHandler(filename="test.log")
22+
file_handler = logging.FileHandler(filename=filename)
1623

1724
#5 handler 별로 다른 level 설정
1825
console.setLevel(logging.INFO)

0 commit comments

Comments
 (0)