Helping VAE by Regularizing Latent Vector

Min Jean Cho

In this article, I’ll suggest an approach that could improve the performance of a variational autoencoder (VAE). If you’re unfamiliar with VAEs, please refer to this article.

Some background story before I get started: I’ve worked on a VAE for a molecule generation task using ZINC and QM9, which are benchmark datasets frequently used for drug discovery. ZINC is a collection of commercially available drug-like molecules, and QM9 is a collection of organic molecules comprising up to nine non-hydrogen atoms. I used a simple VAE architecture with an encoder of four convolutional layers and a decoder of four fully connected layers. After training, the high-dimensional latent space of each dataset was projected into a two-dimensional space (I will describe the projection method in another article).

Figure 1. Two-dimensional visualization of the latent space (\(n_{dim} = 16\)) of the QM9 dataset. The latent vectors are shown as the molecules they encode.

Figure 2. Two-dimensional visualization of the latent space (\(n_{dim} = 256\)) of the ZINC dataset. The latent vectors are shown as the molecules they encode.

Both latent spaces (Figure 1 and Figure 2) inferred by the VAE are continuous, i.e., neighboring latent vectors encode structurally similar molecules. As one can observe by comparing Figure 1 and Figure 2, although the molecules of the ZINC dataset are more complex than those of the QM9 dataset, the latent space of the ZINC dataset appears much simpler than that of the QM9 dataset; as shown in Figure 2, the VAE trained on the ZINC dataset generated molecules that are mostly simple carbon chains. However, both the reconstruction loss and the Kullback-Leibler divergence (KLD) loss during training were very low. So what went wrong for the VAE trained on the ZINC dataset?

KLD Loss vs Reconstruction Loss

The encoder of a VAE aims to find a distribution \(Q(\mathbf{z}|\mathbf{x})\) that is similar to the posterior distribution \(P(\mathbf{z}|\mathbf{x})\), and the distribution \(Q(\mathbf{z}|\mathbf{x})\) defines the latent space.

\[\mathbf{z} \sim Q(\mathbf{z}|\mathbf{x}) \approx P(\mathbf{z}|\mathbf{x}) = \frac{P(\mathbf{x}|\mathbf{z})P(\mathbf{z})}{P(\mathbf{x})}\]

The loss function of VAE consists of two terms, KLD loss and reconstruction loss.

\[L_{VAE} = L_{KLD} + L_{Reconstruction}\]

\[L_{KLD} = D_{KL}[Q(\mathbf{z}|\mathbf{x})||P(\mathbf{z})] = -\frac{1}{2}\sum_{k=1}^p \left[\text{log}\sigma_{Q,k}^2 + 1 - \sigma_{Q,k}^2 - \mu_{Q,k}^2 \right]\]

\[L_{Reconstruction} = -E_Q\left[\text{log}P(\mathbf{x}|\mathbf{z})\right]\]
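As a concrete reference, here is a minimal sketch of how these two loss terms are typically computed in PyTorch. The function and tensor names (`vae_loss`, `x_hat`, `mu`, `logvar`) and the choice of binary cross-entropy for the reconstruction term are my own assumptions, not details of the model described above.

```python
import torch
import torch.nn.functional as F

def vae_loss(x_hat, x, mu, logvar):
    # Reconstruction term: negative log-likelihood of x under the decoder output.
    # Binary cross-entropy is assumed here; the right choice depends on the data.
    recon = F.binary_cross_entropy(x_hat, x, reduction="sum")
    # Closed-form KLD between Q(z|x) = N(mu, sigma^2) and the prior N(0, I),
    # matching the formula above: -1/2 * sum(log sigma^2 + 1 - sigma^2 - mu^2).
    kld = -0.5 * torch.sum(logvar + 1 - logvar.exp() - mu.pow(2))
    return recon + kld
```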

The KLD term tries to make the latent space (the approximate posterior) close to the prior, a multivariate normal distribution \(N(\mathbf{0},\mathbf{I})\) (\(D_{KL} \rightarrow 0\) as \(\mu_{k} \rightarrow 0\) and \(\sigma_{k}^2 \rightarrow 1\)). But as the KLD loss gets very small (i.e., the latent space gets very close to the multivariate normal distribution), the reconstruction loss increases, because gathering the regions occupied by different \(\mathbf{z}_i\) toward the center of the latent space (\(\mathbf{\mu}_Q = \mathbf{0}\)) hinders reconstruction (although it improves interpolation). The reconstruction loss term prevents such dense overlapping. Thus, the two terms balance each other. We can think of the region occupied by each \(\mathbf{z}_i\) as a bubble: the KLD loss pulls the bubbles together, and the reconstruction loss pushes them apart.

Figure 3. Schematic diagrams showing the distribution of latent vectors (as bubbles) in latent space.

The middle panel of Figure 3 shows the latent space when the KLD and reconstruction terms are well balanced. The right panel shows the latent space when the KLD term prevails: it is too close to \(N(\mathbf{0}, \mathbf{I})\), and the bubbles overlap, so reconstruction will not be good. Also note that \(N(\mathbf{0}, \mathbf{I})\) is the prior, not the posterior. The left panel shows that when the reconstruction term prevails, the bubbles are dispersed and there are valleys (discontinuities); due to the valleys, interpolation will not be good (e.g., invalid molecules at inference time), because if a sampled \(\mathbf{z}\) is located in a valley, then \(\mathbf{\tilde{x}}\) will not be generated properly.

Now, consider the reconstruction loss for molecule \(i\) during training: its latent vector \(\mathbf{z}_i\) is a point sampled within a bubble. At epoch \(t>0\), the decoder reconstructs molecule \(i\) from \(\mathbf{z}_i\) using the learned weights, i.e., the reconstruction of molecule \(i\) at epoch \(t\) is an intra-bubble interpolation or extrapolation.

Figure 4. Intra-bubble interpolation and extrapolation.

Here’s what could have happened. I said earlier that both the reconstruction and KLD losses were very low, and that while the majority of molecules in the ZINC dataset are complex, there are relatively simple molecules within the dataset as well. The VAE could make smaller bubbles for the more complex molecules and larger bubbles for the simpler molecules in order to reduce the total loss on average. The location and the size of each bubble are determined by the mean and log variance during reparameterization; the smaller the variance, the closer to the center. Thus, small bubbles for complex molecules might tend to be located near the center of the latent space, and large bubbles for simple molecules might tend to be located outside of the small bubbles.
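For reference, the mean and log variance mentioned here are the two encoder outputs combined in the reparameterization trick; below is a minimal sketch of that step (the variable names are mine).

```python
import torch

def reparameterize(mu, logvar):
    # The bubble's center is mu; its size is governed by sigma = exp(0.5 * logvar).
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)  # epsilon ~ N(0, I)
    return mu + eps * std        # z ~ N(mu, sigma^2)
```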

Figure 5. Bubbles encoding data points with different complexities in a latent space.

Such a pattern of bubble distribution in the latent space manifests when the complexities of the data points to be encoded vary considerably, as in the case of ZINC molecules. When the latent space has such a pattern, latent space (inter-bubble) interpolation may not be satisfactory due to sudden changes near the center of the latent space (the dense region). We can observe a similar pattern for MNIST digit images. The following figure visualizes the latent space of MNIST digit images after training a VAE.

Figure 6. Visualization of the latent space of MNIST digit images generated by the VAE.

In Figure 6, one can observe two very large bubbles for the relatively simple digits 0 and 1. How can we make the latent bubbles more evenly distributed (while using \(N(\mathbf{0},\mathbf{I})\) as the prior distribution)? I suggest two tricks. The first trick is to apply a tanh activation to the encoder’s linear layer that outputs the mean for digit \(i\) (\(\mathbf{\mu}_{i}\)) such that \(\mu_{i,k}\) ranges from \(-3\) to \(+3\), i.e., \(\mathbf{\mu}_i = 3 \times \text{tanh}(\mathbf{x}_i)\), where \(\mathbf{x}_i\) denotes that layer’s pre-activation output. This enforces all bubbles to be located within three standard deviations of the prior (roughly its 99.7% region). By bounding \(\mathbf{\mu}_i\) for all digits, bubbles that had been buried in the stack of small bubbles appear more evenly.
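A minimal sketch of this first trick in PyTorch is shown below, assuming the encoder ends in two linear heads that produce the mean and the log variance (the class and layer names are illustrative, not from my actual model).

```python
import torch
import torch.nn as nn

class BoundedMeanHead(nn.Module):
    """Encoder head that bounds each latent mean to (-3, +3) via 3 * tanh."""
    def __init__(self, hidden_dim, latent_dim):
        super().__init__()
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, h):
        mu = 3.0 * torch.tanh(self.fc_mu(h))  # mu_{i,k} in (-3, +3)
        logvar = self.fc_logvar(h)            # log variance left unbounded in this trick
        return mu, logvar
```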

Figure 7. The latent space of MNIST digit images when \(\mathbf{\mu_i}\) is bounded.

When the mean locations of the digits are bounded using the \(3 \times \text{tanh}\) activation (\(\mathbf{\mu}_i = 3 \times \text{tanh}(\mathbf{x}_i)\)), digit 7 appeared (compare with Figure 6), while the sizes of the other bubbles (including the bubble for digit 1) decreased.

The second trick is to set limits on the log variance such that the bubbles cannot be larger or smaller than a particular size. This is a more direct way to enforce that the bubbles encoding simple objects are not too large and that the bubbles encoding complex objects are not too small. I experimented with this trick by applying a tanh activation to the encoder’s linear layer that outputs the log variance such that \(\text{log}\mathbf{\sigma}_i^2\) ranges from \(-1\) to \(+1\), i.e., \(\text{log}\mathbf{\sigma}_i^2 = \text{tanh}(\mathbf{x}_i)\).
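Analogously, here is a minimal sketch of this second trick, again with illustrative names and assuming the same two-head encoder structure as above.

```python
import torch
import torch.nn as nn

class BoundedLogVarHead(nn.Module):
    """Encoder head that bounds each latent log variance to (-1, +1) via tanh."""
    def __init__(self, hidden_dim, latent_dim):
        super().__init__()
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, h):
        mu = self.fc_mu(h)                      # mean left unbounded in this trick
        logvar = torch.tanh(self.fc_logvar(h))  # log(sigma^2) in (-1, +1)
        return mu, logvar
```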

Figure 8. The latent space of MNIST digit images when \(\text{log}\mathbf{\sigma}_i^2\) is bounded.

As shown in Figure 8, the sizes of the bubbles encoding digits \(0\) and \(1\) have decreased. But the bubble encoding digit \(2\) disappeared, and the reason for this was unclear. I then bounded both \(\mathbf{\mu}_i\) and \(\text{log}\mathbf{\sigma}_i^2\).

Figure 9. The latent space of MNIST digit images when both \(\mathbf{\mu}_i\) and \(\text{log}\mathbf{\sigma}_i^2\) are bounded.

The distribution of bubbles in the latent space improved by bounding both \(\mathbf{\mu}_i\) and \(\text{log}\mathbf{\sigma}_i^2\). To further improve the evenness of the latent space, it seems that the bounds on \(\mathbf{\mu}_i\) and \(\text{log}\mathbf{\sigma}_i^2\) need to be carefully tuned. This can be done by setting \(\mathbf{\mu}_i = m \times \text{tanh}(\mathbf{x}_i)\) and \(\text{log}\mathbf{\sigma}_i^2 = s \times \text{tanh}(\mathbf{x}_i) - v\), where \(m, s,\) and \(v\) are hyperparameters.
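Putting the two tricks together, here is a minimal sketch of the tunable version, with \(m\), \(s\), and \(v\) exposed as constructor arguments (the class name and default values are illustrative assumptions).

```python
import torch
import torch.nn as nn

class BoundedGaussianHead(nn.Module):
    """Encoder head bounding both the mean and the log variance:
    mu in (-m, +m), log(sigma^2) in (-s - v, +s - v)."""
    def __init__(self, hidden_dim, latent_dim, m=3.0, s=1.0, v=0.0):
        super().__init__()
        self.m, self.s, self.v = m, s, v
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, h):
        mu = self.m * torch.tanh(self.fc_mu(h))
        logvar = self.s * torch.tanh(self.fc_logvar(h)) - self.v
        return mu, logvar
```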

The two tricks suggested in this article were tested on the MNIST digit dataset for simplicity. Generative models can sometimes be tricky to train; if you are stuck, feel free to try these two tricks out, as they are extremely simple to implement.