-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathecg_loss.py
62 lines (51 loc) · 2.41 KB
/
ecg_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Implementation of the epsilon-contaminated Gaussian distribution loss,
for use with TensorFlow or with Keras on the TensorFlow backend.
J. Tukey, "A survey of sampling from contaminated distributions,"
Contributions to Probability and Statistics, vol. 2, pp. 448-485, 1960.
"""
import tensorflow as tf
import tensorflow.contrib.distributions as tfd
import numpy as np
from scipy.stats import multivariate_normal
def get_ecg_loss_func(ecg_c=10.0, ecg_epsilon=0.1):
    """Return an epsilon-contaminated Gaussian negative log-likelihood loss.

    The returned function has the Keras loss signature ``(y_true, y_pred)``.
    ``ecg_c`` and ``ecg_epsilon`` are captured by closure, which allows usage
    with Keras as: ``model.compile(loss=get_ecg_loss_func(5.0, 0.2))``.

    For each sample, the likelihood over the last axis is the mixture

        (1 - ecg_epsilon) * N(y_true; y_pred, I)
            + ecg_epsilon * N(y_true; y_pred, ecg_c**2 * I)

    The loss is the mean of its negative log over all remaining axes.

    Args:
        ecg_c: Standard-deviation multiplier of the contaminating component
            (it is passed as ``scale_diag``, not as a variance).
        ecg_epsilon: Mixing weight of the contaminating component,
            expected in [0, 1].

    Returns:
        A callable ``ecg_loss(y_true, y_pred)`` returning a scalar tensor.
    """
    def ecg_loss(y_true, y_pred):
        # Keras may hand over targets in a different dtype than predictions.
        y_true = tf.cast(y_true, y_pred.dtype)
        # ones_like works with a dynamic feature dimension and any rank, and it
        # follows y_pred's dtype. The previous static get_shape()[1] + float32
        # tf.ones could not handle either case.
        unit_scale = tf.ones_like(y_pred)
        log_n = tfd.MultivariateNormalDiag(
            loc=y_pred, scale_diag=unit_scale).log_prob(y_true)
        log_nc = tfd.MultivariateNormalDiag(
            loc=y_pred, scale_diag=unit_scale * ecg_c).log_prob(y_true)
        # Evaluate log((1 - eps) * n + eps * nc) in log space. Taking the log
        # of prob() underflows to log(0) = -inf for large residuals or many
        # dimensions, producing inf losses and NaN gradients.
        # tf.log(0.) == -inf for eps in {0, 1} is handled by reduce_logsumexp.
        log_weights = tf.log(tf.constant(
            [1.0 - ecg_epsilon, ecg_epsilon], dtype=y_pred.dtype))
        log_mix = tf.reduce_logsumexp(
            tf.stack([log_n, log_nc], axis=-1) + log_weights, axis=-1)
        return tf.reduce_mean(-log_mix)
    return ecg_loss
def ecg_loss_np(y_true, y_pred, ecg_c=10.0, ecg_epsilon=0.1):
    """Naive numpy reference implementation of the contaminated-Gaussian loss.

    Computes, row by row, ``-log((1 - eps) * N(t; p, I) + eps * N(t; p, c**2 I))``
    and returns the mean over rows. ``ecg_c`` is a standard-deviation
    multiplier, matching the ``scale_diag`` used by ``get_ecg_loss_func``.

    Args:
        y_true: 2-D array of targets, shape (rows, dims).
        y_pred: 2-D array of predictions, same shape as ``y_true``.
        ecg_c: Standard-deviation multiplier of the contaminating component.
        ecg_epsilon: Mixing weight of the contaminating component.

    Returns:
        The mean negative log-likelihood as a float.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different shapes.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape, got %r and %r"
                         % (y_true.shape, y_pred.shape))

    num_dims = y_pred.shape[1]
    cov_clean = np.identity(num_dims)
    # Covariance is scale**2; ecg_c scales the standard deviation.
    cov_contam = cov_clean * ecg_c ** 2
    # log(0) = -inf is intended for eps in {0, 1}; logaddexp handles it.
    with np.errstate(divide="ignore"):
        log_w_clean = np.log1p(-ecg_epsilon)
        log_w_contam = np.log(ecg_epsilon)

    losses = []
    for true_row, pred_row in zip(y_true, y_pred):
        # Work in log space: pdf() underflows to 0 for large residuals.
        log_n = multivariate_normal.logpdf(true_row, mean=pred_row, cov=cov_clean)
        log_nc = multivariate_normal.logpdf(true_row, mean=pred_row, cov=cov_contam)
        losses.append(-np.logaddexp(log_w_clean + log_n, log_w_contam + log_nc))
    return np.mean(losses)
# Self-check: the numpy reference and the TensorFlow implementation must agree.
if __name__ == "__main__":
    # Fixed seed so that any disagreement can be reproduced exactly.
    rng = np.random.RandomState(0)
    y_true = rng.rand(50, 10)
    y_pred = rng.rand(50, 10)

    loss_np = ecg_loss_np(y_true, y_pred)
    print(loss_np)

    # The graph is fed float32 while the reference runs in float64, so the
    # comparison relies on np.isclose's relative tolerance.
    y_true_tf = tf.placeholder(tf.float32, y_true.shape)
    y_pred_tf = tf.placeholder(tf.float32, y_pred.shape)
    loss_tf = get_ecg_loss_func()(y_true_tf, y_pred_tf)
    with tf.Session() as sess:
        loss_tf_result = sess.run(loss_tf,
                                  feed_dict={y_true_tf: y_true,
                                             y_pred_tf: y_pred})
    print(loss_tf_result)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not np.isclose(loss_np, loss_tf_result):
        raise SystemExit("numpy and TensorFlow losses disagree: %r vs %r"
                         % (loss_np, loss_tf_result))