Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Skew geometric jensen-shannon divergence #35

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions bayesian_torch/layers/base_variational_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,35 @@ def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
sigma_q) + (sigma_q**2 + (mu_q - mu_p)**2) / (2 *
(sigma_p**2)) - 0.5
return kl.mean()

def jsg_div(self, mu_q, sigma_q, mu_p, sigma_p, alpha=0.5):
    """Closed-form skew geometric Jensen-Shannon divergence JSG_alpha(Q||P)
    between two (diagonal) Gaussians.

    Parameters:
        * mu_q: torch.Tensor -> mu parameter of distribution Q
        * sigma_q: torch.Tensor -> sigma (std) parameter of distribution Q
        * mu_p: torch.Tensor or float -> mu parameter of distribution P
        * sigma_p: torch.Tensor or float -> sigma (std) parameter of distribution P
        * alpha: float in [0, 1] -> skew parameter; alpha=0 recovers KL(Q||P)

    returns a 0-dim torch.Tensor (mean divergence over all elements)
    """
    # Accept plain floats for the prior parameters; broadcast against Q.
    mu_p = torch.as_tensor(mu_p, dtype=mu_q.dtype, device=mu_q.device)
    sigma_p = torch.as_tensor(sigma_p, dtype=sigma_q.dtype, device=sigma_q.device)

    var_q = sigma_q.pow(2)
    var_p = sigma_p.pow(2)

    # Variance and mean of the alpha-skewed geometric mean of Q and P
    # (named sigma_0_alpha/mu_0_alpha in the original; note it is a variance).
    var_0_alpha = (var_q * var_p) / ((1 - alpha) * var_q + alpha * var_p)
    mu_0_alpha = var_0_alpha * (alpha * mu_q / var_q + (1 - alpha) * mu_p / var_p)

    # Four terms of the closed-form divergence between Gaussians.
    term1 = ((1 - alpha) * var_q + alpha * var_p) / var_0_alpha
    term2 = torch.log(var_0_alpha / (torch.pow(sigma_q, 2 - 2 * alpha)
                                     * torch.pow(sigma_p, 2 * alpha)))
    term3 = (1 - alpha) * (mu_0_alpha - mu_q).pow(2) / var_0_alpha
    term4 = alpha * (mu_0_alpha - mu_p).pow(2) / var_0_alpha

    jsg_divergence = 0.5 * (term1 + term2 + term3 + term4 - 1)

    return jsg_divergence.mean()
57 changes: 47 additions & 10 deletions bayesian_torch/layers/flipout_layers/conv_flipout.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self,
prior_variance=1,
posterior_mu_init=0,
posterior_rho_init=-3.0,
use_jsg=False,
bias=True):
"""
Implements Conv1d layer with Flipout reparameterization trick.
Expand Down Expand Up @@ -102,6 +103,7 @@ def __init__(self,
self.posterior_mu_init = posterior_mu_init
self.posterior_rho_init = posterior_rho_init
self.bias = bias
self.jsg=use_jsg

self.kl = 0

Expand Down Expand Up @@ -197,18 +199,28 @@ def forward(self, x, return_kl=True):
delta_kernel = (sigma_weight * eps_kernel)

if return_kl:
kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)
if self.jsg:
kl=self.jsg_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)

else:
kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)

bias = None
if self.bias:
sigma_bias = torch.log1p(torch.exp(self.rho_bias))
eps_bias = self.eps_bias.data.normal_()
bias = (sigma_bias * eps_bias)
if return_kl:
kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
if self.jsg:
kl = kl + self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
self.prior_bias_sigma)

else:
kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
self.prior_bias_sigma)

# perturbed feedforward
x_tmp = x * sign_input
perturbed_outputs_tmp = F.conv1d(x * sign_input,
Expand Down Expand Up @@ -257,6 +269,7 @@ def __init__(self,
prior_variance=1,
posterior_mu_init=0,
posterior_rho_init=-3.0,
use_jsg=False,
bias=True):
"""
Implements Conv2d layer with Flipout reparameterization trick.
Expand Down Expand Up @@ -292,6 +305,7 @@ def __init__(self,
self.posterior_mu_init = posterior_mu_init
self.posterior_rho_init = posterior_rho_init
self.bias = bias
self.jsg=use_jsg

self.kl = 0
kernel_size = get_kernel_size(kernel_size, 2)
Expand Down Expand Up @@ -390,20 +404,30 @@ def forward(self, x, return_kl=True):
eps_kernel = self.eps_kernel.data.normal_()

delta_kernel = (sigma_weight * eps_kernel)

if return_kl:
kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)
if self.jsg:
kl=self.jsg_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)

else:
kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)

bias = None
if self.bias:
sigma_bias = torch.log1p(torch.exp(self.rho_bias))
eps_bias = self.eps_bias.data.normal_()
bias = (sigma_bias * eps_bias)
if return_kl:
kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
if self.jsg:
kl = kl + self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
self.prior_bias_sigma)

else:
kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
self.prior_bias_sigma)

# perturbed feedforward
x_tmp = x * sign_input
perturbed_outputs_tmp = F.conv2d(x * sign_input,
Expand Down Expand Up @@ -453,6 +477,7 @@ def __init__(self,
prior_variance=1,
posterior_mu_init=0,
posterior_rho_init=-3.0,
use_jsg=False,
bias=True):
"""
Implements Conv3d layer with Flipout reparameterization trick.
Expand Down Expand Up @@ -483,6 +508,7 @@ def __init__(self,
self.dilation = dilation
self.groups = groups
self.bias = bias
self.jsg=use_jsg

self.kl = 0

Expand Down Expand Up @@ -590,18 +616,29 @@ def forward(self, x, return_kl=True):
delta_kernel = (sigma_weight * eps_kernel)

if return_kl:
kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)
if self.jsg:
kl=self.jsg_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)

else:
kl = self.kl_div(self.mu_kernel, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)

bias = None
if self.bias:
sigma_bias = torch.log1p(torch.exp(self.rho_bias))
eps_bias = self.eps_bias.data.normal_()
bias = (sigma_bias * eps_bias)
if return_kl:
kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
if self.jsg:
kl = kl + self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
self.prior_bias_sigma)

else:
kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
self.prior_bias_sigma)


# perturbed feedforward
x_tmp = x * sign_input
perturbed_outputs_tmp = F.conv3d(x * sign_input,
Expand Down
29 changes: 24 additions & 5 deletions bayesian_torch/layers/flipout_layers/linear_flipout.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self,
prior_variance=1,
posterior_mu_init=0,
posterior_rho_init=-3.0,
use_jsg=False,
bias=True):
"""
Implements Linear layer with Flipout reparameterization trick.
Expand Down Expand Up @@ -91,6 +92,8 @@ def __init__(self,
self.register_buffer('prior_weight_sigma',
torch.Tensor(out_features, in_features),
persistent=False)

self.jsg=use_jsg

if bias:
self.mu_bias = nn.Parameter(torch.Tensor(out_features))
Expand Down Expand Up @@ -136,10 +139,16 @@ def init_parameters(self):

def kl_loss(self):
    """Return the total divergence between the variational posterior and the
    prior over this layer's parameters (weights, plus bias if present).

    Uses the skew geometric Jensen-Shannon divergence when ``self.jsg`` is
    set, otherwise the standard KL divergence.
    """
    sigma_weight = torch.log1p(torch.exp(self.rho_weight))
    if self.jsg:
        kl = self.jsg_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
                          self.prior_weight_sigma)
    else:
        kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
                         self.prior_weight_sigma)
    if self.mu_bias is not None:
        sigma_bias = torch.log1p(torch.exp(self.rho_bias))
        # Bug fix: the JSG branch previously called kl_div here, so JSG mode
        # silently used the KL divergence for the bias term.
        if self.jsg:
            kl += self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                               self.prior_bias_sigma)
        else:
            kl += self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                              self.prior_bias_sigma)
    return kl

def forward(self, x, return_kl=True):
Expand All @@ -153,17 +162,27 @@ def forward(self, x, return_kl=True):

# get kl divergence
if return_kl:
kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)
if self.jsg:
kl=self.jsg_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)

else:
kl = self.kl_div(self.mu_weight, sigma_weight, self.prior_weight_mu,
self.prior_weight_sigma)

bias = None
if self.mu_bias is not None:
sigma_bias = torch.log1p(torch.exp(self.rho_bias))
bias = (sigma_bias * self.eps_bias.data.normal_())
if return_kl:
kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
if self.jsg:
kl = kl + self.jsg_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
self.prior_bias_sigma)

else:
kl = kl + self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
self.prior_bias_sigma)

# linear outputs
outputs = F.linear(x, self.mu_weight, self.mu_bias)
sign_input = x.clone().uniform_(-1, 1).sign()
Expand Down
2 changes: 2 additions & 0 deletions bayesian_torch/layers/flipout_layers/rnn_flipout.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self,
prior_variance=1,
posterior_mu_init=0,
posterior_rho_init=-3.0,
use_jsg=False,
bias=True):
"""
Implements LSTM layer with reparameterization trick.
Expand All @@ -75,6 +76,7 @@ def __init__(self,
self.posterior_mu_init = posterior_mu_init, # mean of weight
self.posterior_rho_init = posterior_rho_init, # variance of weight --> sigma = log (1 + exp(rho))
self.bias = bias
self.jsg=use_jsg

self.kl = 0

Expand Down
Loading