Module src.vae.model

Module containing the encoder, decoder, and VAE model for GMM samples.

"""
Module containing the encoder, decoder, and VAE model for GMM samples.
"""
import numpy as np
import torch
from torch import nn

def init_specific_model(orig_dim, latent_dim, hidden_dim=6):
    """Return an instance of a VAE with encoder and decoder from `model_type`."""
    encoder = Encoder(orig_dim, latent_dim, hidden_dim)
    decoder = Decoder(orig_dim, latent_dim, hidden_dim)
    model = VAE(encoder, decoder)
    return model


class Encoder(nn.Module):
    def __init__(self, orig_dim=10, latent_dim=2, hidden_dim=6):
        r"""Encoder of the model for GMM samples
        """
        super(Encoder, self).__init__()
        # Layer parameters
        self.orig_dim = orig_dim
        self.latent_dim = latent_dim
        
        # Fully connected layers
        self.lin1 = nn.Linear(self.orig_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, 2*hidden_dim)
        self.lin3 = nn.Linear(2*hidden_dim, hidden_dim)
        self.mu_logvar_gen = nn.Linear(hidden_dim, self.latent_dim * 2)
     
    def forward(self, x):
        # Fully connected layers with ReLU activations
        x = torch.relu(self.lin1(x))
        x = torch.relu(self.lin2(x))
        x = torch.relu(self.lin3(x))

        # Fully connected layer for log variance and mean
        # Log std-dev in paper (bear in mind)
        mu_logvar = self.mu_logvar_gen(x)
        mu, logvar = mu_logvar.view(-1, self.latent_dim, 2).unbind(-1)

        return mu, logvar    
    
        
class Decoder(nn.Module):
    def __init__(self, orig_dim=10, latent_dim=2, hidden_dim=6):
        r"""Decoder of the model for GMM samples
        """
        super(Decoder, self).__init__()
        # Layer parameters
        self.orig_dim = orig_dim
        self.latent_dim = latent_dim
        
        # Fully connected layers
        self.lin1 = nn.Linear(latent_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, 2*hidden_dim)
        self.lin3 = nn.Linear(2*hidden_dim, hidden_dim)
        self.lin4 = nn.Linear(hidden_dim, orig_dim)

    def forward(self, z):
        # Fully connected layers with ReLU activations
        x = torch.relu(self.lin1(z))
        x = torch.relu(self.lin2(x))
        x = torch.relu(self.lin3(x))
        x = self.lin4(x)

        return x     
    
    
class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        """
        Class which defines model and forward pass.
        """
        super(VAE, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, mean, logvar):
        """
        Samples from a normal distribution using the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape (batch_size,
            latent_dim)
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mean + std * eps
        else:
            # Reconstruction mode
            return mean

    def forward(self, x):
        """
        Forward pass of model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, orig_dim)
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        reconstruct = self.decoder(latent_sample)
        return reconstruct, latent_dist, latent_sample

    def sample_latent(self, x):
        """
        Returns a sample from the latent distribution.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, orig_dim)
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        return latent_sample    
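
The three pieces returned by VAE.forward are exactly what the usual VAE objective needs. A minimal training-step sketch, assuming the module is importable as src.vae.model and using the standard reconstruction-plus-KL loss (the loss itself is not defined in this module):

import torch
import torch.nn.functional as F
from torch import optim

from src.vae.model import init_specific_model

model = init_specific_model(orig_dim=10, latent_dim=2, hidden_dim=6)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(64, 10)                      # one batch of 10-dimensional GMM samples
reconstruct, (mu, logvar), z = model(x)

recon_loss = F.mse_loss(reconstruct, x, reduction="sum")
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
loss = recon_loss + kl

optimizer.zero_grad()
loss.backward()
optimizer.step()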

Functions

def init_specific_model(orig_dim, latent_dim, hidden_dim=6)

Return a VAE built from an Encoder and a Decoder with the given dimensions.

def init_specific_model(orig_dim, latent_dim, hidden_dim=6):
    """Return an instance of a VAE with encoder and decoder from `model_type`."""
    encoder = Encoder(orig_dim, latent_dim, hidden_dim)
    decoder = Decoder(orig_dim, latent_dim, hidden_dim)
    model = VAE(encoder, decoder)
    return model
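
As a quick illustration, using the same dimensions as the class defaults, the returned model exposes its two halves as registered submodules:

from src.vae.model import init_specific_model

model = init_specific_model(orig_dim=10, latent_dim=2, hidden_dim=6)
print(model.encoder)  # Linear layers 10 -> 6 -> 12 -> 6 -> 4 (2 * latent_dim for mu and logvar)
print(model.decoder)  # Linear layers 2 -> 6 -> 12 -> 6 -> 10
print(sum(p.numel() for p in model.parameters()))  # total number of trainable parameters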

Classes

class Decoder (orig_dim=10, latent_dim=2, hidden_dim=6)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call .to(), etc.

Decoder of the model for GMM samples

class Decoder(nn.Module):
    def __init__(self, orig_dim=10, latent_dim=2, hidden_dim=6):
        r"""Decoder of the model for GMM samples
        """
        super(Decoder, self).__init__()
        # Layer parameters
        self.orig_dim = orig_dim
        self.latent_dim = latent_dim
        
        # Fully connected layers
        self.lin1 = nn.Linear(latent_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, 2*hidden_dim)
        self.lin3 = nn.Linear(2*hidden_dim, hidden_dim)
        self.lin4 = nn.Linear(hidden_dim, orig_dim)

    def forward(self, z):
        # Fully connected layers with ReLU activations
        x = torch.relu(self.lin1(z))
        x = torch.relu(self.lin2(x))
        x = torch.relu(self.lin3(x))
        x = self.lin4(x)

        return x     

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, z)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def forward(self, z):
    # Fully connected layers with ReLU activations
    x = torch.relu(self.lin1(z))
    x = torch.relu(self.lin2(x))
    x = torch.relu(self.lin3(x))
    x = self.lin4(x)

    return x     
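
Because the decoder maps latent codes back to the data space, it can also be used on its own to generate samples from prior draws. A small sketch, assuming the module is importable as src.vae.model:

import torch
from src.vae.model import Decoder

decoder = Decoder(orig_dim=10, latent_dim=2, hidden_dim=6)
z = torch.randn(5, 2)    # 5 latent codes drawn from the standard normal prior
samples = decoder(z)     # shape: (5, 10)
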
class Encoder (orig_dim=10, latent_dim=2, hidden_dim=6)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call .to(), etc.

Encoder of the model for GMM samples

class Encoder(nn.Module):
    def __init__(self, orig_dim=10, latent_dim=2, hidden_dim=6):
        r"""Encoder of the model for GMM samples
        """
        super(Encoder, self).__init__()
        # Layer parameters
        self.orig_dim = orig_dim
        self.latent_dim = latent_dim
        
        # Fully connected layers
        self.lin1 = nn.Linear(self.orig_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, 2*hidden_dim)
        self.lin3 = nn.Linear(2*hidden_dim, hidden_dim)
        self.mu_logvar_gen = nn.Linear(hidden_dim, self.latent_dim * 2)
     
    def forward(self, x):
        # Fully connected layers with ReLU activations
        x = torch.relu(self.lin1(x))
        x = torch.relu(self.lin2(x))
        x = torch.relu(self.lin3(x))

        # Fully connected layer for log variance and mean
        # Log std-dev in paper (bear in mind)
        mu_logvar = self.mu_logvar_gen(x)
        mu, logvar = mu_logvar.view(-1, self.latent_dim, 2).unbind(-1)

        return mu, logvar    

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

def forward(self, x):
    # Fully connected layers with ReLU activations
    x = torch.relu(self.lin1(x))
    x = torch.relu(self.lin2(x))
    x = torch.relu(self.lin3(x))

    # Fully connected layer for log variance and mean
    # Log std-dev in paper (bear in mind)
    mu_logvar = self.mu_logvar_gen(x)
    mu, logvar = mu_logvar.view(-1, self.latent_dim, 2).unbind(-1)

    return mu, logvar    
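
The final linear layer emits 2 * latent_dim values per sample, which the view/unbind step above splits into the mean and the log variance. A quick shape check, assuming the module is importable as src.vae.model:

import torch
from src.vae.model import Encoder

encoder = Encoder(orig_dim=10, latent_dim=2, hidden_dim=6)
x = torch.randn(32, 10)
mu, logvar = encoder(x)
print(mu.shape, logvar.shape)  # torch.Size([32, 2]) torch.Size([32, 2])
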
class VAE (encoder, decoder)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call .to(), etc.

Class which defines model and forward pass.

class VAE(nn.Module):
    def __init__(self, encoder, decoder):
        """
        Class which defines model and forward pass.
        """
        super(VAE, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, mean, logvar):
        """
        Samples from a normal distribution using the reparameterization trick.

        Parameters
        ----------
        mean : torch.Tensor
            Mean of the normal distribution. Shape (batch_size, latent_dim)

        logvar : torch.Tensor
            Diagonal log variance of the normal distribution. Shape (batch_size,
            latent_dim)
        """
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mean + std * eps
        else:
            # Reconstruction mode
            return mean

    def forward(self, x):
        """
        Forward pass of model.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, orig_dim)
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        reconstruct = self.decoder(latent_sample)
        return reconstruct, latent_dist, latent_sample

    def sample_latent(self, x):
        """
        Returns a sample from the latent distribution.

        Parameters
        ----------
        x : torch.Tensor
            Batch of data. Shape (batch_size, orig_dim)
        """
        latent_dist = self.encoder(x)
        latent_sample = self.reparameterize(*latent_dist)
        return latent_sample    

Ancestors

  • torch.nn.modules.module.Module

Methods

def forward(self, x)

Forward pass of model.

Parameters

x : torch.Tensor
Batch of data. Shape (batch_size, orig_dim)
def forward(self, x):
    """
    Forward pass of model.

    Parameters
    ----------
    x : torch.Tensor
        Batch of data. Shape (batch_size, orig_dim)
    """
    latent_dist = self.encoder(x)
    latent_sample = self.reparameterize(*latent_dist)
    reconstruct = self.decoder(latent_sample)
    return reconstruct, latent_dist, latent_sample
def reparameterize(self, mean, logvar)

Samples from a normal distribution using the reparameterization trick.

Parameters

mean : torch.Tensor
Mean of the normal distribution. Shape (batch_size, latent_dim)
logvar : torch.Tensor
Diagonal log variance of the normal distribution. Shape (batch_size, latent_dim)
def reparameterize(self, mean, logvar):
    """
    Samples from a normal distribution using the reparameterization trick.

    Parameters
    ----------
    mean : torch.Tensor
        Mean of the normal distribution. Shape (batch_size, latent_dim)

    logvar : torch.Tensor
        Diagonal log variance of the normal distribution. Shape (batch_size,
        latent_dim)
    """
    if self.training:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + std * eps
    else:
        # Reconstruction mode
        return mean
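
The trick rewrites a draw from N(mean, exp(logvar)) as a deterministic function of the distribution parameters plus independent noise, so gradients flow back through mean and logvar. A standalone sketch of the same computation:

import torch

mean = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)  # log variance 0, i.e. unit variance
std = torch.exp(0.5 * logvar)                   # std = 1.0
eps = torch.randn_like(std)                     # noise, independent of the parameters
z = mean + std * eps                            # differentiable w.r.t. mean and logvar
z.sum().backward()                              # gradients reach mean.grad and logvar.grad
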
def sample_latent(self, x)

Returns a sample from the latent distribution.

Parameters

x : torch.Tensor
Batch of data. Shape (batch_size, orig_dim)
def sample_latent(self, x):
    """
    Returns a sample from the latent distribution.

    Parameters
    ----------
    x : torch.Tensor
        Batch of data. Shape (batch_size, orig_dim)
    """
    latent_dist = self.encoder(x)
    latent_sample = self.reparameterize(*latent_dist)
    return latent_sample
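
Because reparameterize checks self.training, this method is stochastic in training mode and deterministic (it returns the posterior mean) in evaluation mode. A short sketch, assuming the module is importable as src.vae.model:

import torch
from src.vae.model import init_specific_model

model = init_specific_model(orig_dim=10, latent_dim=2)
x = torch.randn(8, 10)

model.train()
z1 = model.sample_latent(x)
z2 = model.sample_latent(x)  # generally differs from z1: fresh noise on every call

model.eval()
z3 = model.sample_latent(x)
z4 = model.sample_latent(x)  # identical to z3: the encoder mean is returned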