Module src.dsets.gaussian_mixture.dset
Expand source code
import torch


# DATASET and DATALOADER
class myDataset(torch.utils.data.Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data = self.data[index]
        if self.transform:
            data = self.transform(data)
        return data


def get_dataloaders(n_samples_per_cluster,
                    latent_means,
                    latent_vars,
                    extra_dim=8,
                    var=0.01,
                    batch_size=100,
                    shuffle=True,
                    return_latents=False):
    """Build a DataLoader over samples drawn from a Gaussian mixture.

    Draws n_samples_per_cluster latent points per cluster, embeds them in a
    higher-dimensional space by appending extra_dim noisy zero-mean
    coordinates, and wraps the result in a DataLoader.
    """
    latent_samples = generate_latent_samples(n_samples_per_cluster, latent_means, latent_vars)
    data = generate_full_samples(latent_samples, extra_dim, var)
    # Pin memory and use a worker process only when a GPU is available.
    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
    data_loader = torch.utils.data.DataLoader(myDataset(data), batch_size=batch_size,
                                              shuffle=shuffle, **kwargs)
    if return_latents:
        return data_loader, latent_samples
    else:
        return data_loader


def samples(mu, var, nb_samples=500):
    """
    Return a tensor of shape (nb_samples, features), sampled
    from the parameterized Gaussian.

    :param mu: torch.Tensor of means
    :param var: torch.Tensor of variances (NOTE: diagonal covariance;
        all off-diagonal covariances are zero)
    """
    out = []
    for i in range(nb_samples):
        out += [
            torch.normal(mu, var.sqrt())
        ]
    return torch.stack(out, dim=0)


def generate_latent_samples(nb_samples, latent_means, latent_vars):
    # Draw nb_samples points from each cluster's Gaussian and concatenate
    # the clusters along the sample dimension.
    latent_data = []
    n_clusters = len(latent_means)
    for i in range(n_clusters):
        cluster = samples(
            torch.Tensor(latent_means[i]),
            torch.Tensor(latent_vars[i]),
            nb_samples=nb_samples
        )
        latent_data.append(cluster)
    return torch.cat(latent_data, dim=0)


def generate_full_samples(latent_samples, extra_dim, var):
    # Pad each latent point with extra_dim zero-mean coordinates, then add
    # isotropic Gaussian noise with variance var to every coordinate.
    out = []
    nb_samples, dim = latent_samples.shape
    for i in range(nb_samples):
        zero = torch.zeros(extra_dim)
        mu = torch.cat([latent_samples[i], zero])
        v = var * torch.ones(dim + extra_dim)
        out += [
            torch.normal(mu, v.sqrt())
        ]
    return torch.stack(out, dim=0)
Functions
def generate_full_samples(latent_samples, extra_dim, var)
Expand source code
def generate_full_samples(latent_samples, extra_dim, var):
    # Pad each latent point with extra_dim zero-mean coordinates, then add
    # isotropic Gaussian noise with variance var to every coordinate.
    out = []
    nb_samples, dim = latent_samples.shape
    for i in range(nb_samples):
        zero = torch.zeros(extra_dim)
        mu = torch.cat([latent_samples[i], zero])
        v = var * torch.ones(dim + extra_dim)
        out += [
            torch.normal(mu, v.sqrt())
        ]
    return torch.stack(out, dim=0)
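A brief usage sketch; the latent tensor and sizes here are invented for illustration:
latents = torch.zeros(4, 2)   # 4 hypothetical latent samples of dimension 2
full = generate_full_samples(latents, extra_dim=8, var=0.01)
print(full.shape)             # torch.Size([4, 10]) -- latent dim + extra_dim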
def generate_latent_samples(nb_samples, latent_means, latent_vars)
Expand source code
def generate_latent_samples(nb_samples, latent_means, latent_vars):
    # Draw nb_samples points from each cluster's Gaussian and concatenate
    # the clusters along the sample dimension.
    latent_data = []
    n_clusters = len(latent_means)
    for i in range(n_clusters):
        cluster = samples(
            torch.Tensor(latent_means[i]),
            torch.Tensor(latent_vars[i]),
            nb_samples=nb_samples
        )
        latent_data.append(cluster)
    return torch.cat(latent_data, dim=0)
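For illustration, sampling two clusters in a 2-D latent space; the means and variances are made up:
means = [[0.0, 0.0], [5.0, 5.0]]       # hypothetical cluster means
variances = [[0.1, 0.1], [0.1, 0.1]]   # hypothetical per-dimension variances
z = generate_latent_samples(100, means, variances)
print(z.shape)   # torch.Size([200, 2]) -- 100 samples per cluster, concatenated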
def get_dataloaders(n_samples_per_cluster, latent_means, latent_vars, extra_dim=8, var=0.01, batch_size=100, shuffle=True, return_latents=False)
Build a DataLoader over samples drawn from a Gaussian mixture.
Expand source code
def get_dataloaders(n_samples_per_cluster,
                    latent_means,
                    latent_vars,
                    extra_dim=8,
                    var=0.01,
                    batch_size=100,
                    shuffle=True,
                    return_latents=False):
    """Build a DataLoader over samples drawn from a Gaussian mixture.

    Draws n_samples_per_cluster latent points per cluster, embeds them in a
    higher-dimensional space by appending extra_dim noisy zero-mean
    coordinates, and wraps the result in a DataLoader.
    """
    latent_samples = generate_latent_samples(n_samples_per_cluster, latent_means, latent_vars)
    data = generate_full_samples(latent_samples, extra_dim, var)
    # Pin memory and use a worker process only when a GPU is available.
    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
    data_loader = torch.utils.data.DataLoader(myDataset(data), batch_size=batch_size,
                                              shuffle=shuffle, **kwargs)
    if return_latents:
        return data_loader, latent_samples
    else:
        return data_loader
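A usage sketch with invented cluster parameters; batch and dimension sizes follow the defaults above:
loader, z = get_dataloaders(
    n_samples_per_cluster=500,
    latent_means=[[0.0, 0.0], [5.0, 5.0]],   # two hypothetical clusters
    latent_vars=[[0.1, 0.1], [0.1, 0.1]],
    return_latents=True,
)
batch = next(iter(loader))
print(batch.shape)   # torch.Size([100, 10]) -- batch_size x (latent dim + extra_dim)
print(z.shape)       # torch.Size([1000, 2])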
def samples(mu, var, nb_samples=500)
Return a tensor of shape (nb_samples, features), sampled from the parameterized Gaussian.
:param mu: torch.Tensor of means
:param var: torch.Tensor of variances (NOTE: diagonal covariance; all off-diagonal covariances are zero)
Expand source code
def samples(mu, var, nb_samples=500):
    """
    Return a tensor of shape (nb_samples, features), sampled
    from the parameterized Gaussian.

    :param mu: torch.Tensor of means
    :param var: torch.Tensor of variances (NOTE: diagonal covariance;
        all off-diagonal covariances are zero)
    """
    out = []
    for i in range(nb_samples):
        out += [
            torch.normal(mu, var.sqrt())
        ]
    return torch.stack(out, dim=0)
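For example, drawing from a single 2-D Gaussian; the parameter values are illustrative:
mu = torch.tensor([0.0, 5.0])
var = torch.tensor([0.1, 0.1])
x = samples(mu, var, nb_samples=500)
print(x.shape)   # torch.Size([500, 2])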
Classes
class myDataset (data, transform=None)
An abstract class representing a Dataset.
All datasets that represent a map from keys to data samples should subclass it. All subclasses should overwrite __getitem__, supporting fetching a data sample for a given key. Subclasses can also optionally overwrite __len__, which is expected to return the size of the dataset by many Sampler implementations and the default options of DataLoader.
Note: DataLoader by default constructs an index sampler that yields integral indices. To make it work with a map-style dataset with non-integral indices/keys, a custom sampler must be provided.
Expand source code
class myDataset(torch.utils.data.Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data = self.data[index]
        if self.transform:
            data = self.transform(data)
        return data
Ancestors
- torch.utils.data.dataset.Dataset
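A usage sketch; the tensor and the scaling transform are invented for illustration:
data = torch.randn(10, 4)                            # 10 samples, 4 features
ds = myDataset(data, transform=lambda x: 2.0 * x)    # transform applied per item
print(len(ds), ds[0].shape)                          # 10 torch.Size([4])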