diff --git a/src/model.py b/src/model.py index 230b83cc2..4b9266279 100644 --- a/src/model.py +++ b/src/model.py @@ -32,8 +32,9 @@ def norm(x, scope, *, axis=-1, epsilon=1e-5): g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1)) b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0)) u = tf.reduce_mean(x, axis=axis, keepdims=True) - s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True) - x = (x - u) * tf.rsqrt(s + epsilon) + n = x-u + s = tf.reduce_mean(tf.square(n), axis=axis, keepdims=True) + x = n * tf.rsqrt(s + epsilon) x = x*g + b return x