Module src.vae.losses

Expand source code
import abc
import math
import os, sys
sys.path.append('../../')
import numpy as np

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
from src.vae.utils import matrix_log_density_gaussian, log_density_gaussian, log_importance_weight_matrix, logsumexp
from src.vae.loss_hessian import hessian_penalty


class Loss(abc.ABC):
    """
    """
    def __init__(self, beta=0., mu=0., lamPT=0., lamCI=0., lam_nearest_neighbor=0.,
                 alpha=0., gamma=0., tc=0., is_mss=True):
        """
        Parameters
        ----------
        beta : float
            Hyperparameter for beta-VAE term.

        mu : float
            Hyperparameter for latent distribution mean.
            
        lamPT : float
            Hyperparameter for penalizing change in one latent induced by another.
            
        lamCI : float
            Hyperparameter for penalizing change in conditional distribution p(z_-j | z_j).
            
        lam_nearest_neighbor : float
            Hyperparameter for penalizing distance to nearest neighbors in each batch.
            
        alpha : float
            Hyperparameter for mutual information term.
            
        gamma : float
            Hyperparameter for dimension-wise KL term.

        tc : float
            Hyperparameter for total correlation term.

        is_mss : bool
            If True, use minibatch stratified sampling when estimating q(z);
            otherwise use minibatch weighted sampling.
        """
        self.beta = beta
        self.mu = mu
        self.lamPT = lamPT
        self.lamCI = lamCI
        self.lam_nearest_neighbor = lam_nearest_neighbor        
        self.alpha = alpha
        self.gamma = gamma
        self.tc = tc
        self.is_mss = is_mss

    def __call__(self, data, recon_data, latent_dist, latent_sample, n_data, latent_output=None):
        """
        Parameters
        ----------
        data : torch.Tensor
            Input data (e.g. batch of images). Shape : (batch_size, n_chan,
            height, width).

        recon_data : torch.Tensor
            Reconstructed data. Shape : (batch_size, n_chan, height, width).
            
        latent_dist: list of torch.Tensor
            Encoder latent distribution [mean, logvar]. Shape : (batch_size, latent_dim).
            
        latent_sample: torch.Tensor
            Latent samples. Shape : (batch_size, latent_dim).
            
        n_data: int
            Total number of training examples. 
            
        latent_output: torch.Tensor, optional
            Output of the Decoder->Encoder mapping of latent sample. Shape : (batch_size, latent_dim).

        Returns
        -------
        loss : torch.Tensor
        """        
        batch_size, latent_dim = latent_sample.shape
        
        self.rec_loss = _reconstruction_loss(data, recon_data)
        self.kl_loss = _kl_normal_loss(*latent_dist) 
        self.mu_loss = _kl_normal_loss(latent_dist[0], torch.zeros_like(latent_dist[1])) 

        log_pz, log_qz, log_qzi, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(latent_sample,
                                                                                      latent_dist,
                                                                                      n_data,
                                                                                      is_mss=self.is_mss)      
        # I[z;x] = KL[q(z,x)||q(x)q(z)] = E_x[KL[q(z|x)||q(z)]]
        self.mi_loss = (log_q_zCx - log_qz).mean()
        # TC[z] = KL[q(z)||\prod_i q(z_i)]
        self.tc_loss = (log_qz - log_prod_qzi).mean()
        # dw_kl_loss is the dimension-wise KL[\prod_i q(z_i)||p(z)] instead of the usual KL[q(z|x)||p(z)]
        self.dw_kl_loss = (log_prod_qzi - log_pz).mean()           
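        # together, mi_loss + tc_loss + dw_kl_loss is a minibatch estimate of the decomposition
        # E_x[ KL[q(z|x)||p(z)] ] = I[z;x] + TC[z] + sum_i KL[q(z_i)||p(z_i)]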

        # total loss
        loss = self.rec_loss + (self.beta * self.kl_loss +
                                self.mu * self.mu_loss)        
        
        # pointwise independence loss
        self.pt_loss = 0
        if self.lamPT > 0 and latent_output is not None:
            for i in range(latent_dim):
                col_idx = np.arange(latent_dim)!=i
                gradients = torch.autograd.grad(latent_output[:,i], latent_sample, grad_outputs=torch.ones_like(latent_output[:,i]), 
                                                retain_graph=True, create_graph=True, only_inputs=True)[0][:,col_idx]   
                self.pt_loss += abs(gradients).mean()
            loss += self.lamPT * self.pt_loss
        
        # local independence loss
        self.ci_loss = 0
        if self.lamCI > 0:
            log_q_zCzi = log_qz.view(batch_size, 1) - log_qzi
            for i in range(latent_dim):
                gradients = torch.autograd.grad(log_q_zCzi[:,i], latent_sample, grad_outputs=torch.ones_like(log_q_zCzi[:,i]), 
                                                retain_graph=True, create_graph=True, only_inputs=True)[0][:,i] 
                self.ci_loss += abs(gradients).mean()     
            loss += self.lamCI * self.ci_loss        
        
        # nearest-neighbor batch loss
        self.nearest_neighbor_loss = 0
        if self.lam_nearest_neighbor > 0:
            for i in range(batch_size):
                dists = torch.pairwise_distance(latent_sample[i], latent_sample)
                # exclude the zero distance of sample i to itself before taking the minimum
                dists = torch.cat([dists[:i], dists[i + 1:]])
                self.nearest_neighbor_loss += dists.min()
            loss += self.lam_nearest_neighbor * self.nearest_neighbor_loss
        
        return loss
    
    
def _reconstruction_loss(data, recon_data):
    """
    Parameters
    ----------
    data : torch.Tensor
        Input data (e.g. batch of images). Shape : (batch_size, n_chan,
        height, width).

    recon_data : torch.Tensor
        Reconstructed data. Shape : (batch_size, n_chan, height, width).

    Returns
    -------
    loss : torch.Tensor
    """
    batch_size = recon_data.size(0)

    loss = F.mse_loss(recon_data, data, reduction="sum") 
    loss = loss / batch_size

    return loss


def _kl_normal_loss(mean, logvar):
    """
    Calculates the KL divergence between a normal distribution
    with diagonal covariance and a unit normal distribution.

    Parameters
    ----------
    mean : torch.Tensor
        Mean of the normal distribution. Shape : (batch_size, latent_dim).

    logvar : torch.Tensor
        Diagonal log variance of the normal distribution. Shape :
        (batch_size, latent_dim).
    """
    latent_dim = mean.size(1)
    # batch mean of kl for each latent dimension
    latent_kl = 0.5 * (-1 - logvar + mean.pow(2) + logvar.exp()).mean(dim=0)
    total_kl = latent_kl.sum()

    return total_kl   


def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True):
    batch_size, hidden_dim = latent_sample.shape

    # calculate log q(z|x)
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # calculate log p(z)
    # mean and log var is 0
    zeros = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, zeros, zeros).sum(1)
    
    # calculate log q(z)
    mat_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    if is_mss:
        # use minibatch stratified sampling importance weights
        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
    else:
        # minibatch weighted sampling: uniform weight 1 / (batch_size * n_data)
        log_iw_mat = torch.full((batch_size, batch_size), -math.log(batch_size * n_data),
                                device=latent_sample.device)
    mat_log_qzi = mat_log_qz + log_iw_mat.view(batch_size, batch_size, 1)

    log_qz = logsumexp(mat_log_qz.sum(2) + log_iw_mat, dim=1, keepdim=False)
    log_qzi = logsumexp(mat_log_qzi, dim=1, keepdim=False)
    log_prod_qzi = log_qzi.sum(1)

    return log_pz, log_qz, log_qzi, log_prod_qzi, log_q_zCx


def _get_log_qz_qzi_perb(latent_sample_perb, latent_dist, n_data, is_mss=True):
    batch_size, hidden_dim, perb_size = latent_sample_perb.shape
    mu, logvar = latent_dist
    
    latent_sample_perb = latent_sample_perb.view(batch_size, 1, hidden_dim, perb_size)    
    mu = mu.view(1, batch_size, hidden_dim, 1)
    logvar = logvar.view(1, batch_size, hidden_dim, 1)
    
    # calculate log q(z)
    mat_log_qz = log_density_gaussian(latent_sample_perb, mu, logvar)

    if is_mss:
        # use minibatch stratified sampling importance weights
        log_iw_mat = log_importance_weight_matrix(batch_size, n_data).to(latent_sample_perb.device)
    else:
        # minibatch weighted sampling: uniform weight 1 / (batch_size * n_data)
        log_iw_mat = torch.full((batch_size, batch_size), -math.log(batch_size * n_data),
                                device=latent_sample_perb.device)
    mat_log_qzi = mat_log_qz + log_iw_mat.view(batch_size, batch_size, 1, 1)

    log_qz = logsumexp(mat_log_qz.sum(2) + log_iw_mat.view(batch_size, batch_size, 1), dim=1, keepdim=False)
    log_qzi = logsumexp(mat_log_qzi, dim=1, keepdim=False)

    return log_qz, log_qzi
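
A minimal usage sketch follows. The vae model interface (a forward pass returning the reconstruction, the [mean, logvar] latent distribution, and a latent sample), the optimizer, and the hyperparameter values are assumptions for illustration only; none of them are defined in this module.

# Hypothetical training step; vae, optimizer, and the forward signature below
# are assumptions, not part of this module.
from src.vae.losses import Loss

loss_f = Loss(beta=4.0, is_mss=True)

def train_step(vae, optimizer, data, n_data):
    recon, latent_dist, latent_sample = vae(data)   # assumed forward signature
    loss = loss_f(data, recon, latent_dist, latent_sample, n_data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()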

            
    
    

Classes

class Loss (beta=0.0, mu=0.0, lamPT=0.0, lamCI=0.0, lam_nearest_neighbor=0.0, alpha=0.0, gamma=0.0, tc=0.0, is_mss=True)

Parameters

beta : float
Hyperparameter for beta-VAE term.
mu : float
Hyperparameter for latent distribution mean.
lamPT : float
Hyperparameter for penalizing change in one latent induced by another.
lamCI : float
Hyperparameter for penalizing change in conditional distribution p(z_-j | z_j).
lam_nearest_neighbor : float
Hyperparameter for penalizing distance to nearest neighbors in each batch.
alpha : float
Hyperparameter for mutual information term.
gamma : float
Hyperparameter for dimension-wise KL term.
tc : float
Hyperparameter for total correlation term.
is_mss : bool
If True, use minibatch stratified sampling when estimating q(z); otherwise use minibatch weighted sampling.
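
Note that only a subset of these weights enters the value returned by __call__: the mutual-information, total-correlation, and dimension-wise-KL estimates weighted by alpha, tc, and gamma are stored as attributes (mi_loss, tc_loss, dw_kl_loss) but are not added to the returned loss. Schematically, the returned objective is:

loss = rec_loss
       + beta * kl_loss
       + mu * mu_loss
       + lamPT * pt_loss                                # only if lamPT > 0 and latent_output is given
       + lamCI * ci_loss                                # only if lamCI > 0
       + lam_nearest_neighbor * nearest_neighbor_loss   # only if lam_nearest_neighbor > 0
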
Expand source code
class Loss(abc.ABC):
    """
    """
    def __init__(self, beta=0., mu=0., lamPT=0., lamCI=0., lam_nearest_neighbor=0.,
                 alpha=0., gamma=0., tc=0., is_mss=True):
        """
        Parameters
        ----------
        beta : float
            Hyperparameter for beta-VAE term.

        mu : float
            Hyperparameter for latent distribution mean.
            
        lamPT : float
            Hyperparameter for penalizing change in one latent induced by another.
            
        lamCI : float
            Hyperparameter for penalizing change in conditional distribution p(z_-j | z_j).
            
        lam_nearest_neighbor : float
            Hyperparameter for penalizing distance to nearest neighbors in each batch.
            
        alpha : float
            Hyperparameter for mutual information term.
            
        gamma : float
            Hyperparameter for dimension-wise KL term.

        tc : float
            Hyperparameter for total correlation term.

        is_mss : bool
            If True, use minibatch stratified sampling when estimating q(z);
            otherwise use minibatch weighted sampling.
        """
        self.beta = beta
        self.mu = mu
        self.lamPT = lamPT
        self.lamCI = lamCI
        self.lam_nearest_neighbor = lam_nearest_neighbor        
        self.alpha = alpha
        self.gamma = gamma
        self.tc = tc
        self.is_mss = is_mss

    def __call__(self, data, recon_data, latent_dist, latent_sample, n_data, latent_output=None):
        """
        Parameters
        ----------
        data : torch.Tensor
            Input data (e.g. batch of images). Shape : (batch_size, n_chan,
            height, width).

        recon_data : torch.Tensor
            Reconstructed data. Shape : (batch_size, n_chan, height, width).
            
        latent_dist: list of torch.Tensor
            Encoder latent distribution [mean, logvar]. Shape : (batch_size, latent_dim).
            
        latent_sample: torch.Tensor
            Latent samples. Shape : (batch_size, latent_dim).
            
        n_data: int
            Total number of training examples. 
            
        latent_output: torch.Tensor, optional
            Output of the Decoder->Encoder mapping of latent sample. Shape : (batch_size, latent_dim).

        Returns
        -------
        loss : torch.Tensor
        """        
        batch_size, latent_dim = latent_sample.shape
        
        self.rec_loss = _reconstruction_loss(data, recon_data)
        self.kl_loss = _kl_normal_loss(*latent_dist) 
        self.mu_loss = _kl_normal_loss(latent_dist[0], torch.zeros_like(latent_dist[1])) 

        log_pz, log_qz, log_qzi, log_prod_qzi, log_q_zCx = _get_log_pz_qz_prodzi_qzCx(latent_sample,
                                                                                      latent_dist,
                                                                                      n_data,
                                                                                      is_mss=self.is_mss)      
        # I[z;x] = KL[q(z,x)||q(x)q(z)] = E_x[KL[q(z|x)||q(z)]]
        self.mi_loss = (log_q_zCx - log_qz).mean()
        # TC[z] = KL[q(z)||\prod_i q(z_i)]
        self.tc_loss = (log_qz - log_prod_qzi).mean()
        # dw_kl_loss is the dimension-wise KL[\prod_i q(z_i)||p(z)] instead of the usual KL[q(z|x)||p(z)]
        self.dw_kl_loss = (log_prod_qzi - log_pz).mean()           
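        # together, mi_loss + tc_loss + dw_kl_loss is a minibatch estimate of the decomposition
        # E_x[ KL[q(z|x)||p(z)] ] = I[z;x] + TC[z] + sum_i KL[q(z_i)||p(z_i)]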

        # total loss
        loss = self.rec_loss + (self.beta * self.kl_loss +
                                self.mu * self.mu_loss)        
        
        # pointwise independence loss
        self.pt_loss = 0
        if self.lamPT > 0 and latent_output is not None:
            for i in range(latent_dim):
                col_idx = np.arange(latent_dim)!=i
                gradients = torch.autograd.grad(latent_output[:,i], latent_sample, grad_outputs=torch.ones_like(latent_output[:,i]), 
                                                retain_graph=True, create_graph=True, only_inputs=True)[0][:,col_idx]   
                self.pt_loss += abs(gradients).mean()
            loss += self.lamPT * self.pt_loss
        
        # local independence loss
        self.ci_loss = 0
        if self.lamCI > 0:
            log_q_zCzi = log_qz.view(batch_size, 1) - log_qzi
            for i in range(latent_dim):
                gradients = torch.autograd.grad(log_q_zCzi[:,i], latent_sample, grad_outputs=torch.ones_like(log_q_zCzi[:,i]), 
                                                retain_graph=True, create_graph=True, only_inputs=True)[0][:,i] 
                self.ci_loss += abs(gradients).mean()     
            loss += self.lamCI * self.ci_loss        
        
        # nearest-neighbor batch loss
        self.nearest_neighbor_loss = 0
        if self.lam_nearest_neighbor > 0:
            for i in range(batch_size):
                dists = torch.pairwise_distance(latent_sample[i], latent_sample)
                # exclude the zero distance of sample i to itself before taking the minimum
                dists = torch.cat([dists[:i], dists[i + 1:]])
                self.nearest_neighbor_loss += dists.min()
            loss += self.lam_nearest_neighbor * self.nearest_neighbor_loss
        
        return loss

Ancestors

  • abc.ABC