Module src.vae.training
import sys

import torch

# trim modules
sys.path.append('../../lib/trim')
sys.path.append('lib/trim')
from trim import DecoderEncoder
class Trainer():
"""
Class to handle training of model.
"""
def __init__(self, model, optimizer, loss_f,
device=torch.device("cpu"),
use_residuals=True):
self.device = device
self.model = model.to(self.device)
self.loss_f = loss_f
self.optimizer = optimizer
self.use_residuals = use_residuals
self._create_latent_map()
def __call__(self, train_loader, test_loader, epochs=10):
"""
Trains the model.
Parameters
----------
data_loader: torch.utils.data.DataLoader
epochs: int, optional
Number of epochs to train the model for.
checkpoint_every: int, optional
Save a checkpoint of the trained model every n epoch.
"""
for epoch in range(epochs):
mean_epoch_loss = self._train_epoch(train_loader, epoch)
mean_epoch_test_loss = self._test_epoch(test_loader)
print('====> Epoch: {} Average train loss: {:.4f} (Test set loss: {:.4f})'.format(epoch, mean_epoch_loss,
mean_epoch_test_loss))
def _create_latent_map(self):
"""
Create saliency object for decoder-encoder map.
Parameters
----------
"""
self.latent_map = DecoderEncoder(self.model, use_residuals=self.use_residuals)
def _train_epoch(self, data_loader, epoch):
"""
Trains the model for one epoch.
Parameters
----------
data_loader: torch.utils.data.DataLoader
epoch: int
Epoch number
Return
------
mean_epoch_loss: float
"""
self.model.train()
n_data = data_loader.dataset.data.shape[0]
epoch_loss = 0.
for batch_idx, data in enumerate(data_loader):
iter_loss = self._train_iteration(data, n_data)
epoch_loss += iter_loss
            if batch_idx % 10 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(data_loader.dataset),
100. * batch_idx / len(data_loader),
epoch_loss / (batch_idx+1)))
mean_epoch_loss = epoch_loss / (batch_idx + 1)
return mean_epoch_loss
def _train_iteration(self, data, n_data):
"""
Trains the model for one iteration on a batch of data.
Parameters
----------
data: torch.Tensor
A batch of data. Shape : (batch_size, channel, height, width).
"""
data = data.to(self.device)
recon_data, latent_dist, latent_sample = self.model(data)
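        # pass the latent sample (and data) through the trim decoder-encoder map;
        # the result is handed to loss_f as latent_output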
latent_output = self.latent_map(latent_sample, data)
loss = self.loss_f(data, recon_data, latent_dist, latent_sample, n_data, latent_output)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
def _test_epoch(self, data_loader):
"""
Tests the model for one epoch.
Parameters
----------
data_loader: torch.utils.data.DataLoader
epoch: int
Epoch number
Return
------
mean_epoch_loss: float
"""
        self.model.eval()
        # gradients are left enabled here: the saliency-based latent map may need them
        n_data = data_loader.dataset.data.shape[0]
epoch_loss = 0.
for batch_idx, data in enumerate(data_loader):
data = data.to(self.device)
recon_data, latent_dist, latent_sample = self.model(data)
latent_output = self.latent_map(latent_sample, data)
loss = self.loss_f(data, recon_data, latent_dist, latent_sample, n_data, latent_output)
iter_loss = loss.item()
epoch_loss += iter_loss
mean_epoch_loss = epoch_loss / (batch_idx + 1)
return mean_epoch_loss
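Trainer delegates the whole objective to loss_f, which is called with six arguments: the input batch, its reconstruction, the latent distribution, the latent sample, the dataset size, and the decoder-encoder map output. Below is a minimal sketch of a compatible callable, assuming latent_dist is a (mu, logvar) pair as in standard VAE code; the latent-consistency term is a hypothetical use of latent_output, not necessarily the loss used with this module.

import torch
from torch.nn import functional as F

def example_loss(data, recon_data, latent_dist, latent_sample, n_data, latent_output):
    mu, logvar = latent_dist  # assumed (mean, log-variance) parameterization
    # per-example reconstruction error
    recon = F.mse_loss(recon_data, data, reduction='sum') / data.size(0)
    # KL divergence of N(mu, exp(logvar)) from the standard normal prior
    kl = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    # hypothetical: penalize the decoder-encoder map for not reproducing the sample
    consistency = F.mse_loss(latent_output, latent_sample)
    # n_data is available for size-dependent weighting; unused in this sketch
    return recon + kl + consistency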
Classes
class Trainer (model, optimizer, loss_f, device=torch.device("cpu"), use_residuals=True)

Handles training and evaluation of a model.
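A minimal usage sketch. ToyVAE, TensorOnlyDataset, and all hyperparameters below are illustrative assumptions, not part of this module; in particular, trim's DecoderEncoder (constructed inside Trainer) may place extra requirements on the model that this toy does not satisfy.

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

class TensorOnlyDataset(Dataset):
    """Yields raw tensors and exposes the .data attribute Trainer reads."""
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return self.data.shape[0]
    def __getitem__(self, i):
        return self.data[i]

class ToyVAE(nn.Module):
    """Forward returns (recon_data, latent_dist, latent_sample), as Trainer expects."""
    def __init__(self, dim=32, latent_dim=4):
        super().__init__()
        self.enc = nn.Linear(dim, 2 * latent_dim)
        self.dec = nn.Linear(latent_dim, dim)
    def forward(self, x):
        mu, logvar = self.enc(x).chunk(2, dim=-1)
        z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterization trick
        return self.dec(z), (mu, logvar), z

train_loader = DataLoader(TensorOnlyDataset(torch.randn(256, 32)), batch_size=64)
test_loader = DataLoader(TensorOnlyDataset(torch.randn(64, 32)), batch_size=64)

model = ToyVAE()
trainer = Trainer(model, optim.Adam(model.parameters(), lr=1e-3),
                  example_loss)  # example_loss: the sketch above
trainer(train_loader, test_loader, epochs=5)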