From 94e7c90541baa87a5935af1a4ef932110a37ddd1 Mon Sep 17 00:00:00 2001
From: rewat7
Date: Sat, 10 Feb 2024 12:44:55 +0530
Subject: [PATCH] Added skew geometric Jensen-Shannon divergence for calculating loss

---
 .../layers/base_variational_layer.py          |  32 ++++
 .../layers/flipout_layers/conv_flipout.py     |  57 ++++++--
 .../layers/flipout_layers/linear_flipout.py   |  29 +++-
 .../layers/flipout_layers/rnn_flipout.py      |   2 +
 .../variational_layers/conv_variational.py    | 137 ++++++++++++++----
 .../variational_layers/linear_variational.py  |  24 ++-
 .../variational_layers/rnn_variational.py     |   2 +
 7 files changed, 235 insertions(+), 48 deletions(-)

diff --git a/bayesian_torch/layers/base_variational_layer.py b/bayesian_torch/layers/base_variational_layer.py
index 4bbaea8..1286d05 100644
--- a/bayesian_torch/layers/base_variational_layer.py
+++ b/bayesian_torch/layers/base_variational_layer.py
@@ -66,3 +66,35 @@ def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
             sigma_q) + (sigma_q**2 + (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5
         return kl.mean()
+
+    def jsg_div(self, mu_q, sigma_q, mu_p, sigma_p, alpha=0.5):
+        '''
+        Calculates the skew geometric Jensen-Shannon divergence between two Gaussians (Q||P).
+
+        Parameters:
+             * mu_q: torch.Tensor -> mu parameter of distribution Q
+             * sigma_q: torch.Tensor -> sigma parameter of distribution Q
+             * mu_p: torch.Tensor -> mu parameter of distribution P
+             * sigma_p: torch.Tensor -> sigma parameter of distribution P
+             * alpha: float -> skew parameter in [0, 1]; alpha=0 recovers KL(Q||P)
+        returns a scalar torch.Tensor (mean over all elements)
+        '''
+
+        sigma_0_alpha = (sigma_q.pow(2) * sigma_p.pow(2)) \
+            / ((1 - alpha) * sigma_q.pow(2) + alpha * sigma_p.pow(2))
+
+        mu_0_alpha = sigma_0_alpha * ((alpha * mu_q / sigma_q.pow(2)) \
+            + ((1 - alpha) * mu_p / sigma_p.pow(2)))
+
+        term1 = ((1 - alpha) * sigma_q.pow(2) + alpha * sigma_p.pow(2)) / sigma_0_alpha
+
+        term2 = torch.log(sigma_0_alpha / (torch.pow(sigma_q, 2 - 2 * alpha) \
+            * sigma_p.pow(2 * alpha)))
+
+        term3 = (1 - alpha) * (mu_0_alpha - mu_q).pow(2) / sigma_0_alpha
+
+        term4 = alpha * (mu_0_alpha - mu_p).pow(2) / sigma_0_alpha
+
+        jsg_divergence = 0.5 * (term1 + term2 + term3 + term4 - 1)
+
+        return jsg_divergence.mean()
diff --git a/bayesian_torch/layers/flipout_layers/conv_flipout.py b/bayesian_torch/layers/flipout_layers/conv_flipout.py
index 0bf7266..3e06a75 100644
--- a/bayesian_torch/layers/flipout_layers/conv_flipout.py
+++ b/bayesian_torch/layers/flipout_layers/conv_flipout.py
@@ -67,6 +67,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements Conv1d layer with Flipout reparameterization trick.
@@ -102,6 +103,7 @@ def __init__(self,
         self.posterior_mu_init = posterior_mu_init
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias
+        self.jsg = use_jsg
 
         self.kl = 0
 
@@ -197,8 +199,13 @@ def forward(self, x, return_kl=True):
             delta_kernel = (sigma_weight * eps_kernel)
 
         if return_kl:
-            kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
-                             self.prior_weight_sigma)
+            if self.jsg:
+                kl = self.jsg_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                  self.prior_weight_sigma)
+
+            else:
+                kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                 self.prior_weight_sigma)
 
         bias = None
         if self.bias:
@@ -206,9 +213,14 @@ def forward(self, x, return_kl=True):
             eps_bias = self.eps_bias.data.normal_()
             bias = (sigma_bias * eps_bias)
             if return_kl:
-                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl = kl + self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
         # perturbed feedforward
         x_tmp = x * sign_input
         perturbed_outputs_tmp = F.conv1d(x * sign_input,
@@ -257,6 +269,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements Conv2d layer with Flipout reparameterization trick.
 
@@ -292,6 +305,7 @@ def __init__(self,
         self.posterior_mu_init = posterior_mu_init
         self.posterior_rho_init = posterior_rho_init
         self.bias = bias
+        self.jsg = use_jsg
         self.kl = 0
 
         kernel_size = get_kernel_size(kernel_size, 2)
@@ -390,10 +404,15 @@ def forward(self, x, return_kl=True):
             eps_kernel = self.eps_kernel.data.normal_()
 
             delta_kernel = (sigma_weight * eps_kernel)
 
         if return_kl:
-            kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
-                             self.prior_weight_sigma)
+            if self.jsg:
+                kl = self.jsg_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                  self.prior_weight_sigma)
+
+            else:
+                kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                 self.prior_weight_sigma)
 
         bias = None
         if self.bias:
@@ -401,9 +420,14 @@ def forward(self, x, return_kl=True):
             eps_bias = self.eps_bias.data.normal_()
             bias = (sigma_bias * eps_bias)
             if return_kl:
-                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl = kl + self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
         # perturbed feedforward
         x_tmp = x * sign_input
         perturbed_outputs_tmp = F.conv2d(x * sign_input,
@@ -453,6 +477,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements Conv3d layer with Flipout reparameterization trick.
@@ -483,6 +508,7 @@ def __init__(self,
         self.dilation = dilation
         self.groups = groups
         self.bias = bias
+        self.jsg = use_jsg
 
         self.kl = 0
 
@@ -590,8 +616,13 @@ def forward(self, x, return_kl=True):
             delta_kernel = (sigma_weight * eps_kernel)
 
         if return_kl:
-            kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
-                             self.prior_weight_sigma)
+            if self.jsg:
+                kl = self.jsg_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                  self.prior_weight_sigma)
+
+            else:
+                kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                 self.prior_weight_sigma)
 
         bias = None
         if self.bias:
@@ -599,9 +630,15 @@ def forward(self, x, return_kl=True):
             eps_bias = self.eps_bias.data.normal_()
             bias = (sigma_bias * eps_bias)
             if return_kl:
-                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl = kl + self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
+
         # perturbed feedforward
         x_tmp = x * sign_input
         perturbed_outputs_tmp = F.conv3d(x * sign_input,
diff --git a/bayesian_torch/layers/flipout_layers/linear_flipout.py b/bayesian_torch/layers/flipout_layers/linear_flipout.py
index 20ec7bd..ae9b81b 100644
--- a/bayesian_torch/layers/flipout_layers/linear_flipout.py
+++ b/bayesian_torch/layers/flipout_layers/linear_flipout.py
@@ -54,6 +54,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements Linear layer with Flipout reparameterization trick.
 
@@ -91,6 +92,8 @@ def __init__(self,
         self.register_buffer('prior_weight_sigma',
                              torch.Tensor(out_features, in_features),
                              persistent=False)
+
+        self.jsg = use_jsg
 
         if bias:
             self.mu_bias = nn.Parameter(torch.Tensor(out_features))
@@ -136,10 +139,16 @@ def init_parameters(self):
 
     def kl_loss(self):
         sigma_weight = torch.log1p(torch.exp(self.rho_weight))
-        kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
+        if self.jsg:
+            kl = self.jsg_div(self.mu_weight, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
+        else:
+            kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu, self.prior_weight_sigma)
         if self.mu_bias is not None:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
-            kl += self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
+            if self.jsg:
+                kl += self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
+            else:
+                kl += self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu, self.prior_bias_sigma)
         return kl
 
     def forward(self, x, return_kl=True):
@@ -153,17 +162,27 @@ def forward(self, x, return_kl=True):
 
         # get kl divergence
         if return_kl:
-            kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
-                             self.prior_weight_sigma)
+            if self.jsg:
+                kl = self.jsg_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
+                                  self.prior_weight_sigma)
+
+            else:
+                kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
+                                 self.prior_weight_sigma)
 
         bias = None
         if self.mu_bias is not None:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             bias = (sigma_bias * self.eps_bias.data.normal_())
             if return_kl:
-                kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl = kl + self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
         # linear outputs
         outputs = F.linear(x, self.mu_weight, self.mu_bias)
 
         sign_input = x.clone().uniform_(-1, 1).sign()
diff --git a/bayesian_torch/layers/flipout_layers/rnn_flipout.py b/bayesian_torch/layers/flipout_layers/rnn_flipout.py
index 44d9043..bc5b317 100644
--- a/bayesian_torch/layers/flipout_layers/rnn_flipout.py
+++ b/bayesian_torch/layers/flipout_layers/rnn_flipout.py
@@ -51,6 +51,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements LSTM layer with reparameterization trick.
 
@@ -75,6 +76,7 @@ def __init__(self,
         self.posterior_mu_init = posterior_mu_init,  # mean of weight
         self.posterior_rho_init = posterior_rho_init,  # variance of weight --> sigma = log (1 + exp(rho))
         self.bias = bias
+        self.jsg = use_jsg
 
         self.kl = 0
 
diff --git a/bayesian_torch/layers/variational_layers/conv_variational.py b/bayesian_torch/layers/variational_layers/conv_variational.py
index dd60d5f..44d349a 100644
--- a/bayesian_torch/layers/variational_layers/conv_variational.py
+++ b/bayesian_torch/layers/variational_layers/conv_variational.py
@@ -74,6 +74,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements Conv1d layer with reparameterization trick.
 
@@ -113,6 +114,7 @@ def __init__(self,
         # variance of weight --> sigma = log (1 + exp(rho))
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
+        self.jsg = use_jsg
 
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size))
@@ -190,18 +192,28 @@ def forward(self, input, return_kl=True):
         weight = self.mu_kernel + tmp_result
 
         if return_kl:
-            kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
-                                    self.prior_weight_mu, self.prior_weight_sigma)
+            if self.jsg:
+                kl_weight = self.jsg_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                         self.prior_weight_sigma)
+
+            else:
+                kl_weight = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                        self.prior_weight_sigma)
         bias = None
 
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = self.mu_bias + (sigma_bias * eps_bias)
             if return_kl:
-                kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl_bias = self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
 
         out = F.conv1d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)
 
@@ -240,6 +252,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements Conv2d layer with reparameterization trick.
@@ -280,6 +293,7 @@ def __init__(self,
         # variance of weight --> sigma = log (1 + exp(rho))
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
+        self.jsg = use_jsg
 
         kernel_size = get_kernel_size(kernel_size, 2)
 
@@ -364,18 +378,28 @@ def forward(self, input, return_kl=True):
         weight = self.mu_kernel + tmp_result
 
         if return_kl:
-            kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
-                                    self.prior_weight_mu, self.prior_weight_sigma)
+            if self.jsg:
+                kl_weight = self.jsg_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                         self.prior_weight_sigma)
+
+            else:
+                kl_weight = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                        self.prior_weight_sigma)
         bias = None
 
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = self.mu_bias + (sigma_bias * eps_bias)
             if return_kl:
-                kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl_bias = self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
 
         out = F.conv2d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)
 
@@ -415,6 +439,7 @@ def __init__(self,
                  padding=0,
                  dilation=1,
                  groups=1,
+                 use_jsg=False,
                  bias=True):
         """
         Implements Conv3d layer with reparameterization trick.
 
@@ -455,6 +480,9 @@ def __init__(self,
         # variance of weight --> sigma = log (1 + exp(rho))
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
+        self.jsg = use_jsg
+
+
         kernel_size = get_kernel_size(kernel_size, 3)
         self.mu_kernel = Parameter(
             torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
@@ -537,18 +565,28 @@ def forward(self, input, return_kl=True):
         weight = self.mu_kernel + tmp_result
 
         if return_kl:
-            kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
-                                    self.prior_weight_mu, self.prior_weight_sigma)
+            if self.jsg:
+                kl_weight = self.jsg_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                         self.prior_weight_sigma)
+
+            else:
+                kl_weight = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                        self.prior_weight_sigma)
         bias = None
 
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = self.mu_bias + (sigma_bias * eps_bias)
             if return_kl:
-                kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl_bias = self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
 
         out = F.conv3d(input, weight, bias, self.stride, self.padding,
                        self.dilation, self.groups)
 
@@ -588,6 +626,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements ConvTranspose1d layer with reparameterization trick.
@@ -628,6 +667,7 @@ def __init__(self,
         # variance of weight --> sigma = log (1 + exp(rho))
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
+        self.jsg = use_jsg
 
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size))
@@ -705,18 +745,28 @@ def forward(self, input, return_kl=True):
         weight = self.mu_kernel + tmp_result
 
         if return_kl:
-            kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
-                                    self.prior_weight_mu, self.prior_weight_sigma)
+            if self.jsg:
+                kl_weight = self.jsg_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                         self.prior_weight_sigma)
+
+            else:
+                kl_weight = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                        self.prior_weight_sigma)
         bias = None
 
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = self.mu_bias + (sigma_bias * eps_bias)
             if return_kl:
-                kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl_bias = self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
 
         out = F.conv_transpose1d(input, weight, bias, self.stride, self.padding,
                                  self.output_padding, self.groups, self.dilation)
 
@@ -758,6 +808,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements ConvTranspose2d layer with reparameterization trick.
 
@@ -798,6 +849,8 @@ def __init__(self,
         # variance of weight --> sigma = log (1 + exp(rho))
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
+        self.jsg = use_jsg
+
         kernel_size = get_kernel_size(kernel_size, 2)
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
@@ -880,18 +933,28 @@ def forward(self, input, return_kl=True):
         weight = self.mu_kernel + tmp_result
 
         if return_kl:
-            kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
-                                    self.prior_weight_mu, self.prior_weight_sigma)
+            if self.jsg:
+                kl_weight = self.jsg_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                         self.prior_weight_sigma)
+
+            else:
+                kl_weight = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                        self.prior_weight_sigma)
         bias = None
 
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = self.mu_bias + (sigma_bias * eps_bias)
             if return_kl:
-                kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl_bias = self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
 
         out = F.conv_transpose2d(input, weight, bias, self.stride, self.padding,
                                  self.output_padding, self.groups, self.dilation)
 
@@ -933,6 +996,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements ConvTranspose3d layer with reparameterization trick.
@@ -974,6 +1038,9 @@ def __init__(self,
         # variance of weight --> sigma = log (1 + exp(rho))
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
+        self.jsg = use_jsg
+
+
         kernel_size = get_kernel_size(kernel_size, 3)
         self.mu_kernel = Parameter(
             torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
@@ -1056,18 +1123,28 @@ def forward(self, input, return_kl=True):
         weight = self.mu_kernel + tmp_result
 
         if return_kl:
-            kl_weight = self.kl_div(self.mu_kernel, sigma_weight,
-                                    self.prior_weight_mu, self.prior_weight_sigma)
+            if self.jsg:
+                kl_weight = self.jsg_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                         self.prior_weight_sigma)
+
+            else:
+                kl_weight = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
+                                        self.prior_weight_sigma)
         bias = None
 
         if self.bias:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             eps_bias = self.eps_bias.data.normal_()
             bias = self.mu_bias + (sigma_bias * eps_bias)
             if return_kl:
-                kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl_bias = self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
 
         out = F.conv_transpose3d(input, weight, bias, self.stride, self.padding,
                                  self.output_padding, self.groups, self.dilation)
 
diff --git a/bayesian_torch/layers/variational_layers/linear_variational.py b/bayesian_torch/layers/variational_layers/linear_variational.py
index 04f2b6a..2a88eb3 100644
--- a/bayesian_torch/layers/variational_layers/linear_variational.py
+++ b/bayesian_torch/layers/variational_layers/linear_variational.py
@@ -59,6 +59,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements Linear layer with reparameterization trick.
@@ -83,6 +84,7 @@ def __init__(self,
         self.posterior_mu_init = posterior_mu_init,  # mean of weight
         # variance of weight --> sigma = log (1 + exp(rho))
         self.posterior_rho_init = posterior_rho_init,
+        self.jsg = use_jsg
         self.bias = bias
 
         self.mu_weight = Parameter(torch.Tensor(out_features, in_features))
@@ -143,13 +145,18 @@ def init_parameters(self):
 
     def kl_loss(self):
         sigma_weight = torch.log1p(torch.exp(self.rho_weight))
-        kl = self.kl_div(
-            self.mu_weight,
-            sigma_weight,
-            self.prior_weight_mu,
-            self.prior_weight_sigma)
+        if self.jsg:
+            kl = self.jsg_div(self.mu_weight, sigma_weight,
+                              self.prior_weight_mu, self.prior_weight_sigma)
+        else:
+            kl = self.kl_div(self.mu_weight, sigma_weight,
+                             self.prior_weight_mu, self.prior_weight_sigma)
         if self.mu_bias is not None:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
-            kl += self.kl_div(self.mu_bias, sigma_bias,
-                              self.prior_bias_mu, self.prior_bias_sigma)
+            if self.jsg:
+                kl += self.jsg_div(self.mu_bias, sigma_bias,
+                                   self.prior_bias_mu, self.prior_bias_sigma)
+            else:
+                kl += self.kl_div(self.mu_bias, sigma_bias,
+                                  self.prior_bias_mu, self.prior_bias_sigma)
         return kl
@@ -164,16 +171,25 @@ def forward(self, input, return_kl=True):
 
         if return_kl:
-            kl_weight = self.kl_div(self.mu_weight, sigma_weight,
+            if self.jsg:
+                kl_weight = self.jsg_div(self.mu_weight, sigma_weight,
                                     self.prior_weight_mu, self.prior_weight_sigma)
+            else:
+                kl_weight = self.kl_div(self.mu_weight, sigma_weight,
+                                        self.prior_weight_mu, self.prior_weight_sigma)
         bias = None
 
         if self.mu_bias is not None:
             sigma_bias = torch.log1p(torch.exp(self.rho_bias))
             bias = self.mu_bias + (sigma_bias * self.eps_bias.data.normal_())
             if return_kl:
-                kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                if self.jsg:
+                    kl_bias = self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                       self.prior_bias_sigma)
+                else:
+                    kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
+                                          self.prior_bias_sigma)
+
         out = F.linear(input, weight, bias)
 
diff --git a/bayesian_torch/layers/variational_layers/rnn_variational.py b/bayesian_torch/layers/variational_layers/rnn_variational.py
index b2eda8a..1fd3c13 100644
--- a/bayesian_torch/layers/variational_layers/rnn_variational.py
+++ b/bayesian_torch/layers/variational_layers/rnn_variational.py
@@ -51,6 +51,7 @@ def __init__(self,
                  prior_variance=1,
                  posterior_mu_init=0,
                  posterior_rho_init=-3.0,
+                 use_jsg=False,
                  bias=True):
         """
         Implements LSTM layer with reparameterization trick.
 
@@ -76,6 +77,7 @@ def __init__(self,
         # variance of weight --> sigma = log (1 + exp(rho))
         self.posterior_rho_init = posterior_rho_init,
         self.bias = bias
+        self.jsg = use_jsg
 
         self.ih = LinearReparameterization(
             prior_mean=prior_mean,
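
Usage sketch (not part of the patch): the snippet below shows one way the new use_jsg flag and the jsg_div helper could be exercised once the diff above is applied. The class name, module path, and method signatures are taken from the files touched by this patch; the (output, divergence) return value of forward and the alpha=0 comparison against kl_div are assumptions based on the existing layer API and on the closed form implemented in jsg_div.

# Sketch only -- assumes the patched bayesian_torch from this diff is installed.
import torch
from bayesian_torch.layers.variational_layers.linear_variational import LinearReparameterization

# A mean-field linear layer whose divergence term is the skew geometric
# Jensen-Shannon divergence instead of the default KL divergence.
layer = LinearReparameterization(in_features=16, out_features=4, use_jsg=True)

x = torch.randn(8, 16)
out, div = layer(x)  # expected to return (output, divergence) when return_kl=True
loss = out.pow(2).mean() + 1e-3 * div  # the divergence enters the loss as a weighted regularizer
loss.backward()

# Sanity check of the closed form: at alpha=0 the geometric mean collapses onto P,
# so jsg_div(..., alpha=0.0) should agree with the analytic KL(Q||P) in kl_div.
mu_q, mu_p = torch.randn(5), torch.randn(5)
sigma_q, sigma_p = torch.rand(5) + 0.1, torch.rand(5) + 0.1
assert torch.allclose(layer.jsg_div(mu_q, sigma_q, mu_p, sigma_p, alpha=0.0),
                      layer.kl_div(mu_q, sigma_q, mu_p, sigma_p), atol=1e-5)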