-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
56 lines (45 loc) · 1.96 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Copyright (c) 2018, Curious AI Ltd. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Custom loss functions"""
import torch
from torch.nn import functional as F
from torch.autograd import Variable
# import pdb
def softmax_mse_loss(input_logits, target_logits):
    """Take softmax on both sides and return the MSE loss between them.

    Both tensors are softmaxed along dim 1, which is treated as the class
    dimension. The squared difference of the two distributions is summed
    over every element and then divided by the number of classes.

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to inputs but not the targets: ``target_logits`` is
      detached inside this function.

    Args:
        input_logits: Logits that receive gradients. Presumably shaped
            (batch, num_classes) — only dim 1 is required to be the class
            dimension; confirm against callers.
        target_logits: Logits with the same size as ``input_logits``.

    Returns:
        Scalar tensor: summed squared error between the two softmax
        distributions, divided by ``input_logits.size(1)``.

    Raises:
        AssertionError: if the two tensors differ in size (only while
            assertions are enabled).
    """
    assert input_logits.size() == target_logits.size()
    input_softmax = F.softmax(input_logits, dim=1)
    # Detach so the documented "no gradient to targets" contract is enforced
    # here instead of relying on every caller to detach the targets first.
    target_softmax = F.softmax(target_logits.detach(), dim=1)
    num_classes = input_logits.size(1)
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    return F.mse_loss(input_softmax, target_softmax, reduction='sum') / num_classes
def softmax_kl_loss(input_logits, target_logits):
    """Take softmax on both sides and return the KL divergence.

    Computes KL(softmax(target_logits) || softmax(input_logits)), with the
    softmax taken along dim 1 (the class dimension). This is
    ``F.kl_div(log q, p)`` where q comes from the inputs and p from the
    targets.

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.
    - Sends gradients to inputs but not the targets: ``target_logits`` is
      detached inside this function.

    Args:
        input_logits: Logits that receive gradients. Presumably shaped
            (batch, num_classes) — confirm against callers.
        target_logits: Logits with the same size as ``input_logits``.

    Returns:
        Scalar tensor: KL divergence summed over all elements.

    Raises:
        AssertionError: if the two tensors differ in size (only while
            assertions are enabled).
    """
    assert input_logits.size() == target_logits.size()
    input_log_softmax = F.log_softmax(input_logits, dim=1)
    # Detach so the documented "no gradient to targets" contract is enforced
    # here instead of relying on every caller to detach the targets first.
    target_softmax = F.softmax(target_logits.detach(), dim=1)
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    return F.kl_div(input_log_softmax, target_softmax, reduction='sum')
def symmetric_mse_loss(input1, input2):
    """Return a summed squared error that back-propagates into both arguments.

    Neither argument is detached, so gradients reach ``input1`` and
    ``input2`` alike. The squared difference is summed over every element and
    divided by the size of dim 1 (the class dimension).

    Note:
    - Returns the sum over all examples. Divide by the batch size afterwards
      if you want the mean.

    Args:
        input1: Tensor with at least two dimensions.
        input2: Tensor with the same size as ``input1``.

    Returns:
        Scalar tensor: ``sum((input1 - input2) ** 2) / input1.size(1)``.

    Raises:
        AssertionError: if the two tensors differ in size (only while
            assertions are enabled).
    """
    assert input1.shape == input2.shape
    squared_diff = (input1 - input2).pow(2)
    return squared_diff.sum() / input1.shape[1]