1.2. disentanglement#

1.2.1. VAEs#

Some good disentangled VAE implementations are here and more general VAE implementations are here. TensorFlow implementations are available here.

The goal is to obtain a nice latent representation \(\mathbf z\) for our inputs \(\mathbf x\). To do this, we learn parameters \(\phi\) for the encoder \(p_\phi( \mathbf z\vert \mathbf x)\) and \(\theta\) for the decoder \(q_{\mathbf \theta} ( \mathbf x\vert \mathbf z)\). We do this with the standard vae setup, whereby a code \(\mathbf z\) is sampled using the output of the encoder via the reparameterization trick (intro to VAEs here).
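
As a reference point, here is a minimal sketch of this encoder/decoder setup in PyTorch (layer sizes and latent dimensionality are arbitrary assumptions):

import torch
import torch.nn as nn

class VAE(nn.Module):
    '''Minimal VAE: the encoder outputs (mu, logvar) of p_phi(z|x); the decoder gives q_theta(x|z).'''
    def __init__(self, n_inputs=784, n_latent=10):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(n_inputs, 400), nn.ReLU())
        self.fc_mu = nn.Linear(400, n_latent)
        self.fc_logvar = nn.Linear(400, n_latent)
        self.decoder = nn.Sequential(nn.Linear(n_latent, 400), nn.ReLU(),
                                     nn.Linear(400, n_inputs), nn.Sigmoid())

    def forward(self, x):
        h = self.encoder(x)
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # reparameterization trick: z ~ p_phi(z|x)
        return self.decoder(z), mu, logvar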

1.2.1.1. disentangled vae losses#

  • reconstruction loss - encourages accurate reconstruction of the input

  • compactness prior loss - encourages points to be compactly placed in space

  • total correlation loss - encourages latent variables to be independent

  • summarizing the losses

    • reconstruction loss - measures the quality of the reconstruction; the form of the loss changes based on the assumed likelihood distribution of each pixel

      • binary cross entropy loss - corresponds to bernoulli distr., most common - doesn’t penalize (0.1, 0.2) and (0.4, 0.5) the same way, which might be problematic

      • mse loss - gaussian distr. - tends to focus on a few pixels that are very wrong

      • l1 loss - laplace distr.

    • compactness prior loss

      • doesn’t use the extra injected latent noise

      • tries to push all the points to the same place

      • emphasises smoothness of z, using as few dimensions of z as possible, and the main axes of z to capture most of the data variability

      • usually assume prior is standard normal, resulting in pushing the code means to 0 and code variance to 1

      • we can again split this term (in expectation over \(\mathbf x\)): \(\sum_i \underbrace{\text{KL} \left(p_\phi( \mathbf z_i\vert \mathbf x)\:\vert\vert\:prior(\mathbf z_i) \right)}_{\text{compactness prior loss}} = \underbrace{\sum_i I(\mathbf x; \mathbf z_i)}_{\text{mutual info}} + \underbrace{\sum_i \text{KL} \left(p_\phi( \mathbf z_i)\:\vert\vert\:prior(\mathbf z_i) \right)}_{\text{factorial prior loss}}\)

    • total correlation loss - encourages factors to be independent

      • measures the dependence between the latent variables (KL between the joint \(p_\phi(\mathbf z)\) and the product of its marginals)

      • intractable (requires a pass through the whole dset)

      • instead, sample \(\mathbf z \sim p_\phi(\mathbf z\vert \mathbf x)\) over a batch (giving samples from the aggregate \(p_\phi(\mathbf z)\)) and create samples from \(\prod_j p_\phi( \mathbf z_j) \) by permuting each latent dimension across the batch dimension

        • now, calculate the KL with the density-ratio trick - train a classifier to distinguish the two sets of samples and use it to approximate the ratio (see the sketch just below this list)
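
A minimal sketch of this permutation / density-ratio idea, assuming a 10-dimensional latent space and a small, arbitrary classifier architecture:

import torch
import torch.nn as nn

def permute_dims(z):
    '''Shuffle each latent dimension independently across the batch,
    turning samples from q(z) into samples from prod_j q(z_j).'''
    B, D = z.shape
    return torch.stack([z[torch.randperm(B), j] for j in range(D)], dim=1)

## small density-ratio classifier: trained (with cross-entropy) to tell samples of q(z)
## apart from samples of prod_j q(z_j) produced by permute_dims
discriminator = nn.Sequential(nn.Linear(10, 64), nn.LeakyReLU(0.2), nn.Linear(64, 2))

def total_correlation_estimate(z):
    '''TC[z] = KL[q(z) || prod_j q(z_j)] ~ mean of (logit_joint - logit_permuted) over samples of q(z).'''
    logits = discriminator(z)
    return (logits[:, 0] - logits[:, 1]).mean()

The VAE then adds a weighted total_correlation_estimate(z) to its loss, while the discriminator itself is trained with a separate cross-entropy objective on real vs. permuted codes.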

1.2.1.2. disentangled vae in code#

import torch
import torch.nn.functional as F

## Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(x_reconstructed, x, mu, logvar, beta=1):
    '''
    Params
    ------
    x_reconstructed: torch.Tensor
        Reconstructed input, with values between 0-1
    x: torch.Tensor
        Input, with values between 0-1 (treated as bernoulli targets)
    mu, logvar: torch.Tensor
        Mean and log-variance output by the encoder
    '''

    ## reconstruction loss (assuming bernoulli distr.)
    ## BCE = -sum_i [x_i * log(x_rec_i) + (1 - x_i) * log(1 - x_rec_i)]
    rec_loss = F.binary_cross_entropy(x_reconstructed, x, reduction='sum')

    ## compactness prior loss
    ## KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    ## total correlation loss (calculated the tc-vae way)
    ## func is a helper that estimates log p(z), log q(z), log prod_j q(z_j), log q(z|x) (a sketch is given below)
    z_sample = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
    log_pz, log_qz, log_prod_qzi, log_q_zCx = func(z_sample, mu, logvar)
    ## I[z;x] = KL[q(z,x) || q(x)q(z)] = E_x[KL[q(z|x) || q(z)]]
    mi_loss = (log_q_zCx - log_qz).mean()
    ## TC[z] = KL[q(z) || prod_i q(z_i)]
    tc_loss = (log_qz - log_prod_qzi).mean()
    ## dimension-wise KL: sum_i KL[q(z_i) || p(z_i)] instead of the usual KL[q(z|x) || p(z)]
    dw_kl_loss = (log_prod_qzi - log_pz).mean()

    ## beta-vae loss; a beta-tc-vae would instead weight mi_loss, tc_loss, and dw_kl_loss separately
    return rec_loss + beta * KLD
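
The func helper above is left unspecified; a sketch of it using the minibatch-weighted sampling estimator from the beta-TC-VAE paper might look like the following (the dataset size data_size is an assumed constant):

import math
import torch

def func(z_sample, mu, logvar, data_size=60000):
    '''Minibatch-weighted estimates of log p(z), log q(z), log prod_j q(z_j), and log q(z|x).
    data_size is the (assumed) number of training examples N.'''
    def log_gaussian(z, m, lv):  # elementwise log N(z; m, diag(exp(lv)))
        return -0.5 * (math.log(2 * math.pi) + lv + (z - m) ** 2 / lv.exp())

    batch_size = z_sample.shape[0]
    log_q_zCx = log_gaussian(z_sample, mu, logvar).sum(dim=1)                                    # log q(z|x)
    log_pz = log_gaussian(z_sample, torch.zeros_like(mu), torch.zeros_like(logvar)).sum(dim=1)   # log p(z)

    ## evaluate every sample z_i under every posterior q(.|x_j): shape (batch, batch, latent_dim)
    mat = log_gaussian(z_sample.unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0))
    norm = math.log(batch_size * data_size)                                                      # MWS normalizer
    log_qz = torch.logsumexp(mat.sum(dim=2), dim=1) - norm                                       # log q(z)
    log_prod_qzi = (torch.logsumexp(mat, dim=1) - norm).sum(dim=1)                               # log prod_j q(z_j)
    return log_pz, log_qz, log_prod_qzi, log_q_zCx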

1.2.1.3. vaes for interpretation#

  • icam (bass et al. 2020) - learn disentangled repr using vae with adv loss to make repr class-relevant

    • va-gan (baumgartner et al. 2018) - interpret features in GAN space

1.2.1.4. various vaes#

  • vae (kingma & welling, 2013)

  • beta-vae (higgins et al. 2017) - add hyperparameter \(\beta\) to weight the compactness prior term

  • beta-vae H (burgess et al. 2018) - add parameter \(C\) to control the contribution of the compactness prior term

    • \(\overbrace{\mathbb E_{p_\phi(\mathbf z\vert \mathbf x)}}^{\text{samples}} [ \underbrace{-\log q_{\mathbf \theta} ( \mathbf x\vert \mathbf z)}_{\text{reconstruction loss}} ] + \textcolor{teal}{\beta}\; \vert\sum_i \underbrace{\text{KL} \left(p_\phi( \mathbf z_i\vert \mathbf x)\:\vert\vert\:prior(\mathbf z_i) \right)}_{\text{compactness prior loss}} -C\vert\)

    • C is gradually increased from zero (allowing for a larger compactness prior loss) until good quality reconstruction is achieved (a short sketch of this capacity-annealed loss is given after this list)

  • factor-vae (kim & mnih 2018) - adds total correlation loss term

    • computes total correlation loss term using discriminator (can we discriminate between the samples when we shuffle over the batch dimension or not?)

    • beta-TC-VAE = beta-total-correlation VAE (chen et al. 2018) - same objective but computed without need for discriminator

      • use minibatch-weighted sampling to compute each of the 3 terms that make up the original VAE compactness prior loss

      • main idea is to better approximate \(q(z)\) by weighting samples appropriately - biased, but easier to compute

    • Interpretable VAEs for nonlinear group factor analysis

  • Wasserstein Auto-Encoders (tolstikhin et al.) - removes the mutual info part of the loss

    • wasserstein distance = earth-movers distance, how far apart are 2 distrs

    • minimizes wasserstein distance + penalty which is similar to auto-encoding penalty, without the mutual info term

    • another intuition: rather than map each point to a ball (since VAE adds noise to each latent repr), we only constrain the overall distr of Z, potentially making reconstructions less blurry (but potentially making latent space less smooth)

  • Adversarial Latent Autoencoder (pidhorskyi et al. 2020)

    • improve quality of generated VAE reconstructions by using a different setup which allows for using a GAN loss

  • Variational Autoencoders Pursue PCA Directions (by Accident)

    • local orthogonality of the embedding transformation

    • prior \(p(z)\) is standard normal; the encoder is assumed to be Gaussian with some mean and diagonal covariance

    • disentanglement is sensitive to rotations of the latent embeddings but reconstruction err doesn’t care

    • for linear autoencoder w/ square-error as reconstruction loss, we recover PCA decomp

  • Disentangling Disentanglement in Variational Autoencoders (2019)

    • independence can be too simplistic; instead, 2 things matter:

      • the latent encodings of data having an appropriate level of overlap

        • keeps encodings from just being a lookup table

        • when encoder is unimodal, \(I(x; z)\) gives us a good handle on this

      • prior structure on the latents (e.g. independence, sparsity)

    • to trade these off, can penalize divergence between \(q_\phi(z)\) and \(p(z)\)

    • nonisotropic priors - isotropic priors are invariant to rotations of the latent space, so they only identify it up to rotation

      • by choosing a nonisotropic prior (e.g. nonisotropic gaussian), can learn certain directions more easily

    • sparse prior - can help do clustering

  • VAE-SNE: a deep generative model for simultaneous dimensionality reduction and clustering (graving & couzin 2020) - reduce dims + cluster without specifying number of clusters

    • stochastic neighbor regularizer that optimizes pairwise similarity kernels between original and latent distrs. to strengthen local neighborhood preservation

      • can use different neighbor kernels, e.g. t-SNE similarity (van der Maaten & Hinton, 2008) or Gaussian SNE kernel (Hinton & Roweis, 2003)

      • perplexity annealing technique (Kobak and Berens, 2019) - decay the size of local neighborhoods during training (helps to preserve structure at multiple scales)

    • Gaussian mixture prior for learning latent distr. (with very large number of clusters)

    • extensive evaluation - test several datasets / methods and evaluate how well the first 2 dimensions preserve the following:

      • global - correlation between pairwise distances in orig/latent spaces

      • local - both metric (distance- or radius-based) and topological (neighbor-based) neighborhoods which are 1% of total embedding size

      • fine-scale - neighborhoods which are <1% of total embedding size

      • temporal info (for time-series data only) - correlation between latent and original temporal derivatives

      • likelihood on out-of-sample data

    • further advancements

      • embedding into polar coordinates (rather than Euclidean) helps a lot

      • convolutional VAE-SNE - extract features from images using some pre-trained net and then run VAE-SNE on these features

    • background: earlier works also used SNE objective for regularization - starts with van der Maaten 2009 (parametric t-SNE)

    • future work: density-preserving versions of t-SNE, modeling hierarchical structure in vae, conditional t-SNE kernel

  • A Survey of Inductive Biases for Factorial Representation-Learning (ridgeway 2016)

    • desiderata

      • compact

      • faithful - preserve info required for task

      • explicitly represent the attributes required for the task at hand

      • interpretable by humans

    • factorial representation - attributes are statistically independent and can provide a useful bias for learning

      • “compete” - factors are more orthogonal

      • “cooperate” - factors are more similar

      • bias on distribution of factors

        • PCA - minimize reconstruction err. subject to orthogonal weights

        • ICA - maximize non-Gaussianity (can also have sparse ICA)

      • bias on factors being invariant to certain types of changes

        • ISA (independent subspace analysis) - 2 layer model where first layer is linear, 2nd layer pools first layer (not maxpool, more like avgpool), sparsity at second layer

          • i.e. 1st layer cooperates, 2nd layer competes

        • VQ - vector quantizer - like ISA but first layer filters now compete and 2nd layer cooperates

        • SOM - encourages topographic map by enforcing nearby filters to be similar

      • bias in how factors are combined

        • linear combination - PCA/ICA

        • multilinear models - multiplicative interactions between factors (e.g on top of ISA)

        • functional parts - factor components are combined to construct the output

          • ex. NMF - parts can only add to, not subtract from, the total output

          • ex. have each pixel in the output be represented by only one factor in a VQ

        • hierarchical layers

          • ex. R-ICA - recursive ICA - run ICA on coefficients from previous layer (after some transformation)

    • supervision bias

      • constraints on some examples

        • e.g. some groups have same value for a factor

        • e.g. some examples have similar distances (basis for MDS = multidimensional scaling)

        • e.g. analogies between examples

        • can do all of these things with auto-encoders

  • more papers
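
Referring back to the capacity-controlled objective of burgess et al. above (the beta-vae H bullet), here is a minimal sketch of the annealed loss; the hyperparameter values are arbitrary assumptions:

def capacity_annealed_loss(rec_loss, kl, step, beta=1000., C_max=25., anneal_steps=100000):
    '''rec_loss and kl are the reconstruction loss and summed KL to the prior (as computed above).'''
    C = min(C_max, C_max * step / anneal_steps)   # C grows linearly from 0 to C_max over training
    return rec_loss + beta * (kl - C).abs()       # rec_loss + beta * |sum_i KL_i - C|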

1.2.2. GANs#

1.2.2.1. model-based (disentangle during training)#

1.2.2.2. post-hoc (disentangle after training)#

1.2.3. misc#

  • Learning Diverse and Discriminative Representations via the Principle of Maximal Coding Rate Reduction (yu, …, & ma, 2020)

    • goal: learn low-dimensional structure from high-dim (labeled or unlabeled) data

    • approach: instead of cross-entropy loss, use maximal coding rate reduction = MCR loss function to learn linear feature space where:

      • inter-class discriminative - features of samples from different classes/clusters are uncorrelated and lie in different low-dim linear subspaces

      • intra-class compressible - features of samples from same class/cluster are correlated (i.e. belong to low-dim linear subspace)

      • maximally diverse - dimension (or variance) of features for each class/cluster should be as large as possible as long as uncorrelated from other classes/clusters

    • related to nonlinear generalized PCA

    • given random variable \(z\) and precision \(\epsilon\), rate distortion \(R(z, \epsilon)\) is minimal number of bits to encode \(z\) such that expected decoding err is less than \(\epsilon\)

      • can compute from finite samples

      • can compute for each class (diagonal matrices represent class/cluster membership in loss function)

      • MCR maximizes (rate distortion for all features) - (rate distortion for all data separated into classes) (see the sketch after this list)

        • like a generalization of information gain

    • evaluation

      • performs better than standard cross-entropy training when labels are corrupted

  • Learned Equivariant Rendering without Transformation Supervision - separate foreground / background using video
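
Referring back to the coding-rate objective above, a sketch of the finite-sample rate estimate and the resulting rate reduction; the feature shape convention (d, m) and the value of \(\epsilon\) are assumptions:

import torch

def coding_rate(Z, eps=0.5):
    '''Finite-sample coding rate R(Z, eps) ~ 0.5 * logdet(I + d/(m*eps^2) Z Z^T) for Z of shape (d, m).'''
    d, m = Z.shape
    I = torch.eye(d, dtype=Z.dtype)
    return 0.5 * torch.logdet(I + (d / (m * eps ** 2)) * Z @ Z.T)

def coding_rate_reduction(Z, labels, eps=0.5):
    '''MCR objective: rate of all features minus the rate of the features split by class.'''
    d, m = Z.shape
    rate_all = coding_rate(Z, eps)
    rate_classes = 0.
    for c in labels.unique():
        Zc = Z[:, labels == c]                    # features of class c, shape (d, m_c)
        m_c = Zc.shape[1]
        I = torch.eye(d, dtype=Z.dtype)
        rate_classes += (m_c / (2 * m)) * torch.logdet(I + (d / (m_c * eps ** 2)) * Zc @ Zc.T)
    return rate_all - rate_classes                # maximize this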

1.2.4. (semi)-supervised disentanglement#

these papers use some form of supervision for the latent space when disentangling

1.2.5. evaluating disentanglement#

  • Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations (locatello et al. 2019)

    • state of disentanglement is very poor…depends a lot on architecture/hyperparameters

    • good way to evaluate: make the inductive biases explicit, investigate benefits of the resulting disentanglement

    • defining disentanglement - compact, interpretable, independent, helpful for downstream tasks, causal inference

      • a change in one factor of variation should lead to a change in a single factor in the learned repr.

    • unsupervised learning of disentangled reprs. is impossible without inductive biases

  • note - vaes come with a reconstruction loss + compactness prior loss, which can be looked at on their own

  • data

  • beta-vae disentanglement metric score = higgins metric - see if we can capture known disentangled repr. using pairs of things where only one thing changes

    • start with a known generative model that has an observed set of independent and interpretable factors (e.g. scale, color, etc.) that can be used to simulate data.

    • create a dataset comprised of pairs of generated data for which a single factor is held constant (e.g. a pair of images which have objects with the same color).

    • use the inference network to map each pair of images to a pair of latent variables.

    • train a linear classifier to predict which interpretable factor was held constant based on the latent representations. The accuracy of this predictor is the disentanglement metric score (a rough code sketch is given after this list).

  • Evaluating Disentangled Representations (sepliarskaia et al. 2019)

    • defn 1 (Higgins et al., 2017; Kim and Mnih, 2018; Eastwood and Williams, 2018) = factorVAE metric: A disentangled representation is a representation where a change in one latent dimension corresponds to a change in one generative factor while being relatively invariant to changes in other generative factors.

    • defn 2 (Locatello et al., 2018; Kumar et al., 2017): A disentangled representation is a representation where a change in a single generative factor leads to a change in a single factor in the learned representation.

    • metrics

      • DCI: Eastwood and Williams (2018) - informativeness based on predicting gt factors using latent factors

      • SAP: Kumar et al. (2017) - for each gt factor, the gap between how well the best and second-best latent dimensions predict it

      • mutual info gap MIG: Chen et al. 2018 - computes a similar gap using mutual info (normalized by the factor’s entropy)

      • modularity (ridgeway & mozer, 2018) - whether each dimension of r(x) depends on at most one factor of variation, measured using their mutual info
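
A rough, simplified sketch of the beta-vae (higgins) metric described above; generate and encode are hypothetical helpers, and one pair is used per training point rather than averaging over a batch of pairs:

import numpy as np
from sklearn.linear_model import LogisticRegression

def higgins_metric(generate, encode, n_factors, n_pairs=5000):
    '''generate(fixed_factor=k) returns a pair of samples sharing factor k;
    encode(x) returns the latent code for x (both are hypothetical helpers).'''
    diffs, labels = [], []
    for _ in range(n_pairs):
        k = np.random.randint(n_factors)               # factor held constant for this pair
        x1, x2 = generate(fixed_factor=k)
        diffs.append(np.abs(encode(x1) - encode(x2)))  # per-dimension code difference
        labels.append(k)
    clf = LogisticRegression(max_iter=1000).fit(np.array(diffs), np.array(labels))
    return clf.score(np.array(diffs), np.array(labels))  # accuracy = metric score (train accuracy, for simplicity)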

1.2.6. non-deep methods#

  • unifying vae and nonlinear ica (khemakhem et al. 2020)

    • ICA

      • maximize non-gaussianity of \(z\) - use kurtosis, negentropy (a FastICA usage sketch is given below)

      • minimize mutual info between components of \(z\) - use KL, max entropy
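
As a non-deep baseline, a short FastICA usage sketch (the toy sources and mixing matrix are arbitrary):

import numpy as np
from sklearn.decomposition import FastICA

## FastICA recovers independent components by maximizing non-gaussianity (a negentropy approximation)
rng = np.random.RandomState(0)
S = np.c_[np.sign(rng.randn(1000, 1)), rng.laplace(size=(1000, 1))]  # two independent non-gaussian sources
A = np.array([[1., 0.5], [0.5, 1.]])                                 # arbitrary mixing matrix
X = S @ A.T                                                          # observed mixtures
Z = FastICA(n_components=2, random_state=0).fit_transform(X)         # estimated independent components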