Introducing dynamical constraints into representation learning

Dedi Wang Biophysics Program and Institute for Physical Science and Technology, University of Maryland, College Park 20742, USA. Yihang Wang Luke Evans Pratyush Tiwary
Abstract

While representation learning has been central to the rise of machine learning and artificial intelligence, a key problem remains in making the learned representations meaningful. For this, the typical approach is to regularize the learned representation through prior probability distributions. However, such priors are usually unavailable or ad hoc. To deal with this, we propose a dynamics-constrained representation learning framework. Instead of using predefined probabilities, we restrict the latent representation to follow specific dynamics, which is a more natural constraint for representation learning in dynamical systems. Our belief stems from a fundamental observation in physics that though different systems can have different marginalized probability distributions, they typically obey the same dynamics, such as Newton’s and Schrödinger’s equations. We validate our framework for different systems including a real-world fluorescent DNA movie dataset. We show that our algorithm can uniquely identify an uncorrelated, isometric and meaningful latent representation.

Introduction

The ability to learn meaningful and useful representations of data is a key challenge in applying machine learning (ML) and artificial intelligence (AI) to various real-world problems. On one hand, a useful representation extracts and organizes the discriminative information from data to support effective machine learning for downstream tasks. On the other hand, a meaningful representation is often more interpretable, which is fundamental to helping humans trust AI. Recently, a variety of representation learning methods have been proposed based on the idea of autoencoding, that is, learning a mapping from high-dimensional data to a low-dimensional representation or latent space from which the original data can be approximately reconstructed.[1] However, while the ability to reconstruct the original data might make a representation useful, it does not necessarily make it meaningful. The traditional and popular way to tackle this problem has been to introduce some generic priors on the representation. In a seminal paper on representation learning, Bengio et al. proposed a set of so-called meta-priors, including attributes such as hierarchical organization, disentanglement of explanatory factors, concentration of data along low-dimensional manifolds, clusterability, as well as temporal and spatial coherence.[2] These meta-priors are believed to be as ubiquitous in meaningful representation learning as the postulates and laws of physics.

In this work, we emphasize the importance of introducing physics-based dynamical laws into representation learning, which avoids hand-tuning the form of the meta-priors. Instead, we find that attributes normally associated with meaningful representations naturally and directly arise from our framework, which we call Dynamics Constrained Auto-encoder (DynamicsAE). To our knowledge, this is the first representation learning algorithm which regularizes the latent space purely based on its dynamical properties. For the widely adopted model of Brownian or overdamped Langevin dynamics,[3] such a dynamical constraint makes the proposed model able to recover the ground-truth latent variables up to an isometry, which has been theoretically proved in Ref. [4]. In this work, we focus specifically on introducing priors corresponding to overdamped Langevin dynamics, but the framework should be easily generalized to other physics-based dynamics models. We validate our algorithm through a variety of stochastic models and datasets, including a fluorescent DNA movie dataset. For these diverse examples, we demonstrate that the DynamicsAE can successfully identify the true latent factors, while competing approaches such as $\beta$-VAE[5] fail to do so.

Results

General introduction to autoencoder-based representation learning

The goal of autoencoder-based representation learning algorithms in general is to learn a low-dimensional latent representation $z$ such that the original high-dimensional observations $x$ can be reconstructed from the low-dimensional representation. Most of the recent advances in representation learning are built upon the pioneering work of the Variational Autoencoder (VAE).[1, 6] These algorithms consist of at least three important components: an encoder $q_\phi(z|x)$, a decoder $p_\psi(x|z)$, and a prior distribution $p_\theta(z)$ for the latent representation space. The encoder maps the input $x$ to the representation space $z$, while the decoder reconstructs the input data from the representation. In addition, the encoder is encouraged to satisfy some structure specified by the prior distribution $p_\theta(z)$ on the latent space, where $\theta$ denotes trainable parameters for the prior. The objective function of VAE, $\mathcal{L}_{\mathrm{rep}}$, where the subscript rep signifies the representation learning process, can be written as follows:

$$\mathcal{L}_{\mathrm{rep}} = \mathbb{E}_{\hat{p}(x)}\left[-\mathbb{E}_{q_\phi(z|x)} \log p_\psi(x|z) + \beta\, D_{\mathrm{KL}}\big(q_\phi(z|x)\,\|\,p_\theta(z)\big)\right] \tag{1}$$

where $\hat{p}(x) = \frac{1}{N}\sum_{n=1}^{N}\delta(x - x^n)$ is the empirical data distribution, and $N$ is the number of data points. The stochastic encoder $q_\phi(z|x)$ and the stochastic decoder $p_\psi(x|z)$ are usually parameterized by neural networks with the model parameters $\phi$ and $\psi$. This objective function can be divided into two parts: the first term measures the ability of our representation to reconstruct the input, while the second term can be interpreted as a complexity penalty that acts as a regularizer. The trade-off between prediction capacity and model complexity can then be controlled by a hyper-parameter $\beta$. Compared to the original VAE objective where $\beta = 1$, a modified version named $\beta$-VAE shows better performance in terms of disentanglement with an appropriately tuned $\beta$.[5]
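The two terms of this objective can be made concrete with a small numerical sketch. It is not the authors' implementation: it assumes a diagonal Gaussian encoder, a standard normal prior and a unit-variance Gaussian decoder (so the reconstruction term reduces to a squared error), and the function names `gaussian_kl`, `squared_error` and `beta_vae_loss` are hypothetical. Plain Python is used only to keep the sketch dependency-free.

```python
import math

def gaussian_kl(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) )."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))

def squared_error(x, x_rec):
    """Negative log-likelihood of a unit-variance Gaussian decoder, up to a constant."""
    return 0.5 * sum((a - b) ** 2 for a, b in zip(x, x_rec))

def beta_vae_loss(batch, beta):
    """Batch average of (reconstruction term + beta * KL term).

    Each item is (x, x_rec, mu, log_var). How x_rec, mu and log_var are
    produced (the encoder and decoder networks) is outside this sketch.
    """
    total = 0.0
    for x, x_rec, mu, log_var in batch:
        total += squared_error(x, x_rec) + beta * gaussian_kl(mu, log_var)
    return total / len(batch)

# Two hand-made items (assumed values, for illustration only).
batch = [
    ([1.0, 2.0], [1.0, 2.0], [0.0, 0.0], [0.0, 0.0]),  # perfect reconstruction, posterior equals prior
    ([0.0, 0.0], [1.0, 0.0], [1.0, 0.0], [0.0, 0.0]),  # reconstruction error 0.5, KL 0.5
]
loss_beta1 = beta_vae_loss(batch, beta=1.0)
loss_beta4 = beta_vae_loss(batch, beta=4.0)
```

Raising `beta` increases only the weight of the KL term, which is the trade-off between reconstruction and regularization described above.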

Learning the latent representation with dynamic constraints

Minimizing the objective in Eq. 1 requires a definition of the prior distribution $p_\theta(z)$. However, such a predefined distribution is usually unavailable a priori. Choosing a simplistic prior like the standard Gaussian distribution, as done in VAE, could lead to over-regularization, resulting in poor latent representations. Therefore, several flexible prior distributions have been proposed, such as a mixture of Gaussians prior,[7] VampPrior,[8] and normalizing flows.[9] However, if the prior distribution is too flexible, it might defeat its purpose and no longer serve as a regularizer for the learned latent space.

To deal with this dilemma, here we propose the dynamics constrained representation learning framework summarized in Fig. 1. Instead of using a predefined prior distribution to regularize the latent space, we force the latent representation to follow a specific yet generic class of dynamics. Our inspiration here stems from physics, where even though different systems can and likely will have different probability distributions along the same low-dimensional projection $z$, they will obey the same dynamical equations, such as Newton’s second law in classical mechanics and Schrödinger’s equation in quantum mechanics. Therefore, we believe this is a more natural and generic constraint for representation learning.

Figure 1: a. Network architecture used for the dynamics constrained framework. The encoder, the decoder, the diffusion field and the force field are nonlinear deep neural networks. b. A flowchart illustrating the dynamics constrained framework.

For Markovian processes, the dynamics can be described by the transition density $p(z_{t+\Delta t} \mid z_t)$ with $\Delta t > 0$. For instance, for deterministic dynamical processes, the transition density would equal a delta function. A tempting idea then is to match the encoded transition density $q_\phi(z_{t+\Delta t} \mid z_t)$ to some specific prior transition density $p_\theta(z_{t+\Delta t} \mid z_t)$. This can be done by minimizing the sliced-Wasserstein distance between these two conditional probability distributions if enough samples from each distribution are available (Methods).[10, 11] Samples from the encoded transition density can be obtained, for instance, by performing swarms of short trajectories starting from each point $z_t$. In many cases, however, such burst trajectories are unavailable. Therefore, it would be desirable if this approach could also be applied to a single long trajectory. Here we approximate the sliced-Wasserstein distance by discretizing the latent space into bins $\{\Omega_i\}$. In other words, instead of forcing $q_\phi(z_{t+\Delta t} \mid z_t)$ to follow some specific distribution for every point $z_t$, we encourage $q_\phi(z_{t+\Delta t} \mid z_t \in \Omega_i)$ within each bin $\Omega_i$ to follow some specific distribution. When the number of bins is large enough, this should provide a reasonable approximation. The discretization and resampling schemes are discussed in Methods. With this we can propose a dynamics constrained regularizer:

(2)

where and is a prior transition probability which is related to the dynamics of the latent representation. In this way, even if an accurate prior distribution is unavailable, we can still regularize the latent space by forcing its dynamics to follow some specific form of the prior transition probability .

Similarly, the reconstruction loss in Eq. 1 can also be reformulated in terms of bins and generalized to learning dynamics. For simplicity, we only consider in this paper the case of mappings where the latent variable is a deterministic function of the input data given as , but the algorithm should be easily generalized to stochastic encoders. Let denote the deterministic decoder that maps the latent variable back to the original data space. In this way, the reconstruction loss can be replaced by the -loss, namely:

(3)

Combining the reconstruction loss in Eq. 3 and the regularizer in Eq. 13, we can tune the hyper-parameter to force the latent space to follow specific dynamics given the prior transition density . In the next subsection, we will discuss how to choose such a prior transition density.

Learning the dynamics prior through Langevin dynamics

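We focus in this paper on stochastic systems governed by the overdamped Langevin equation, or Brownian dynamics, as the underlying dynamics, but the framework should be more generally applicable to other dynamical systems. Consider a system where the representation $z \in \mathbb{R}^d$ obeys Brownian dynamics,

$$dz_t = \left[\frac{D(z_t)\,F(z_t)}{k_B T} + \nabla \cdot D(z_t)\right] dt + \sqrt{2 D(z_t)}\; dW_t \tag{4}$$

where $F(z)$ is the force field experienced along $z$, $D(z)$ is the diffusion matrix, and $W_t$ is a Brownian motion in $\mathbb{R}^d$. Further, the force field can also be given by $F(z) = -\nabla G(z)$, where we have introduced the free energy $G(z)$. If we assume the diffusion matrix in the learned representation space is diagonal, Eq. 4 can be simplified as:

$$dz_{t,k} = \left[\frac{D_k(z_t)\,F_k(z_t)}{k_B T} + \frac{\partial D_k(z_t)}{\partial z_k}\right] dt + \sqrt{2 D_k(z_t)}\; dW_{t,k} \tag{5}$$

where $D_k \equiv D_{kk}$ and $k = 1, \dots, d$. This assumption of a diagonal diffusion matrix allows us to obtain a set of uncorrelated latent variables. For simplicity, in further derivations we absorb the time unit into the diffusion matrix and set $k_B T = 1$. The prior transition density of the system parameterized by $\theta$ can then be obtained from a single discretized step of Eq. 5:

$$p_\theta(z_{t+\Delta t} \mid z_t) = \prod_{k=1}^{d} \frac{1}{\sqrt{4\pi D_k(z_t)}} \exp\left(-\frac{\left[z_{t+\Delta t,k} - z_{t,k} - D_k(z_t) F_k(z_t) - \frac{\partial D_k(z_t)}{\partial z_k}\right]^2}{4 D_k(z_t)}\right) \tag{6}$$

where both the force field $F$ and the diffusion matrix $D$ are nonlinear functions of the representation $z$ parameterized by neural network parameters $\theta$. Given the pairwise samples $(z_t, z_{t+\Delta t})$ from each bin $\Omega_i$, $i = 1, \dots, M$, both the force field and the diffusion matrix can be inferred through likelihood maximization or equivalently by minimizing the following prior loss:

$$\mathcal{L}_{\mathrm{prior}} = -\frac{1}{M}\sum_{i=1}^{M} \frac{1}{|\Omega_i|} \sum_{z_t \in \Omega_i} \log p_\theta(z_{t+\Delta t} \mid z_t) \tag{7}$$

Once the estimated force field and the diffusion matrix are obtained, we can easily draw samples from the prior transition density using Eq. 5, which can then be used to regularize the representation learned through Eq. 13.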
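The sampling and likelihood steps described above can be sketched as follows. This is an illustration under stated assumptions rather than the authors' code: it uses the standard one-step (Euler-Maruyama) Gaussian transition for overdamped Langevin dynamics with a diagonal diffusion, units in which $k_B T = 1$, and the time step absorbed into the diffusion. The force and diffusion functions are a hand-written toy model standing in for the neural networks, and all names are hypothetical.

```python
import math
import random

def transition_mean_var(z, force, diffusion, diffusion_grad):
    """Mean and variance of the one-step Gaussian transition density.

    Assumed form: mean_k = z_k + D_k(z) F_k(z) + dD_k/dz_k, var_k = 2 D_k(z).
    """
    F, D, dD = force(z), diffusion(z), diffusion_grad(z)
    mean = [zk + Dk * Fk + dDk for zk, Fk, Dk, dDk in zip(z, F, D, dD)]
    var = [2.0 * Dk for Dk in D]
    return mean, var

def sample_next(z, force, diffusion, diffusion_grad, rng):
    """Draw one sample from the prior transition density."""
    mean, var = transition_mean_var(z, force, diffusion, diffusion_grad)
    return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(mean, var)]

def prior_nll(pairs, force, diffusion, diffusion_grad):
    """Average negative log-likelihood of (z_t, z_next) pairs."""
    total = 0.0
    for z, z_next in pairs:
        mean, var = transition_mean_var(z, force, diffusion, diffusion_grad)
        total += sum(0.5 * math.log(2.0 * math.pi * v) + (zn - m) ** 2 / (2.0 * v)
                     for zn, m, v in zip(z_next, mean, var))
    return total / len(pairs)

# Toy model (assumption): harmonic free energy G(z) = |z|^2 / 2, so F(z) = -z,
# with a constant diffusion of 0.05 in every latent dimension.
def force(z):
    return [-zk for zk in z]

def wrong_force(z):
    return [zk for zk in z]  # deliberately wrong sign, for comparison

def diffusion(z):
    return [0.05 for _ in z]

def diffusion_grad(z):
    return [0.0 for _ in z]

rng = random.Random(0)
traj = [[1.0, -1.0]]
for _ in range(2000):
    traj.append(sample_next(traj[-1], force, diffusion, diffusion_grad, rng))
pairs = list(zip(traj[:-1], traj[1:]))

nll_true = prior_nll(pairs, force, diffusion, diffusion_grad)
nll_wrong = prior_nll(pairs, wrong_force, diffusion, diffusion_grad)
```

On this toy trajectory the model that generated the data attains a lower negative log-likelihood than the model with the wrong force, which is the signal that likelihood maximization exploits when fitting the force and diffusion networks.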

Figure 2: Recovering the underlying dynamics from the transformed three well model system. The original simulated data (a,b), the transformed data (c,d), and the latent representation learned by our algorithm (e,f) are shown in the figure. The black arrows represent the force field (left) while the ellipses represent the diffusion field (right). The ellipses in d are highly stretched due to the extremely anisotropic and inhomogeneous diffusion field caused by the nonlinear mapping function.

In fact, given this generalizable framework, we can introduce even more constraints into the prior. For instance, in our experiments we observed that enforcing a constant diffusion largely stabilized the optimization process of our algorithm and prevented it from learning a trivial scaling of the latent space. Namely, we incorporate an inductive bias into the learning process and regularize the latent representation by directly using the samples generated from the simplified Langevin dynamics obtained after using $D_k = 1$ for every latent dimension $k$ in Eq. 5:

$$dz_{t,k} = F_k(z_t)\, dt + \sqrt{2}\; dW_{t,k} \tag{8}$$

By learning a latent space with isotropic and homogeneous diffusion as in Eq. 8, we aim to preserve some underlying geometric properties of the data, as we show later through numerical examples.

Figure 3: Learning the underlying dynamics from the thirty-dimensional coordinates of alanine dipeptide in water. (a) Structure of alanine dipeptide. The main coordinates describing slow transitions are the torsion angles $\phi$ (C-N-Cα-C) and $\psi$ (N-Cα-C-N), but the neural network input is only the Cartesian coordinates of the heavy atoms. (b) Free energy surface of alanine dipeptide in water at 300K along the dihedral angles $\phi$ and $\psi$. (c,d) show the latent representation learned by our algorithm. The black arrows represent the force field (c) while the ellipses represent the diffusion field (d). (e,f) illustrate the relationship between our learned latent representation and the ground-truth latent factors $\phi$ and $\psi$.

Putting it all together

To sum up, the algorithm optimizes the following two objective functions, given in Eqs. 9 and 11:

(9)

where,

(10)

and,

(11)

The resulting workflow is summarized in Fig. 1 and Extended Alg. 1. Our algorithm iteratively discretizes the latent space into bins and resamples the training dataset to focus on poorly sampled regions (see Methods). The representation learning is then implemented via a two-step optimization process: the first step fixes the prior transition density and concentrates on learning a better representation by minimizing the representation loss (Eq. 9); the second step fixes the latent representation and focuses on inferring the transition density by optimizing the prior loss (Eq. 11), which in turn provides the information to guide the representation learning process.

Model potential

We first demonstrate the utility of our method using data obtained by simulating a 2D three-well model potential. The simulated data is warped with nonlinear mapping functions to mimic the distortion of real-world observations before being fed into our algorithm (Supplementary Information or SI). The simulated data and its transformation are shown in Fig. 2. In the original data space, the diffusion is constant (Fig. 2(b)). After the transformation, however, the diffusion varies dramatically from point to point, as shown in Fig. 2(d). Feeding this transformed trajectory data into our algorithm, we find that it successfully recovers the original variables up to an invertible linear transformation, wherein the diffusion is once again isotropic and homogeneous. The results also demonstrate the ability of our algorithm to recover the underlying kinetics even in rarely sampled regions, such as the transition state in the three-well model system.

Atomic resolution biomolecule in water

As a slightly more complex example, we further illustrate the power of our algorithm on the well-studied alanine dipeptide molecule in TIP3P water. We obtain a molecular dynamics trajectory for this system (see SI for details). The backbone torsion angles $\phi$ and $\psi$ are known to be the most important reaction coordinates separating different metastable states of alanine dipeptide in water, as shown in Fig. 3(a-b). We completely ignore prior knowledge of these dihedrals and, to make the task more challenging, we directly work with a 30-dimensional space comprising the three-dimensional coordinates of the 10 heavy atoms as input. To avoid issues due to trivial rotation and translation of the whole system, we align all the configurations to the first frame of the molecular dynamics trajectory. Fig. 3(c-d) illustrates that our algorithm is still able to recover the underlying dynamics even with a 2D Euclidean latent space. Even more interestingly, a clear correspondence between our learned latent representation and the torsion angles $\phi$ and $\psi$ is found, as shown in Fig. 3(e-f). This suggests that DynamicsAE can be used to recover physically meaningful representations such as the dihedral angles in alanine dipeptide.

Figure 4: Comparison of the results on dSprites dataset with two generative factors (X-Position, Y-Position) obtained from $\beta$-VAE and dynamics-AE. a. Reconstructions of dSprites dataset (top: $\beta$-VAE, bottom: dynamics-AE). First row: originals. Second row: reconstructions. Remaining two rows: reconstructions of latent traversals across each latent dimension. b. Recovering the latent representation from the sequential 2D dSprites images. The top row represents the latent representation learned from $\beta$-VAE. The bottom row represents the latent space learned from dynamics-AE.

dSprites dataset

We then examine our algorithm on a 2D image dataset (dSprites) with known ground-truth generative factors. For easier visualization, only two of the available transformations (X-Position, Y-Position) of the dSprites dataset are considered at first. To generate a long sequence of images, we take a random initial image and sequentially transform it according to a Gaussian random walk along these two transformations. Further details of data generation can be found in SI. As shown in Fig. 4, we directly compare our algorithm with $\beta$-VAE. Fig. 4(a) shows that both algorithms are capable of accurately reconstructing the input images and disentangling the two ground-truth factors x-position and y-position. However, as the ground-truth factors should follow a uniform distribution by construction, it is evident that only our algorithm uniquely identifies the ground-truth factors up to a linear transformation in this case. This can be clearly seen by comparing the two free energy surfaces in Fig. 4(b), where the latent representation learned by $\beta$-VAE follows a Gaussian distribution while the latent representation learned by our algorithm follows a uniform distribution. This simple but important result is not surprising: $\beta$-VAE cannot recover the ground truth because of its false prior assumption of a Gaussian distribution in the latent space. Therefore, regularizing the latent space via its dynamics, instead of constraining its distribution to a prior distribution, offers a clear advantage here.
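The sequence generation described above can be sketched as a random walk over the position factors. The actual protocol is given in the SI and is not reproduced here; this sketch assumes a reflecting boundary, a step size of 0.05 in normalized units, and a grid of 32 values per position factor (the size of the X- and Y-position factors in dSprites). The function names are hypothetical.

```python
import random

def reflect(u):
    """Fold a real number back into [0, 1] by reflecting at the boundaries."""
    u = u % 2.0
    return 2.0 - u if u > 1.0 else u

def random_walk_indices(n_steps, n_values=32, step_std=0.05, seed=0):
    """Gaussian random walk in normalized (x, y), mapped to grid indices.

    Each returned pair (ix, iy) selects one image by its X-Position and
    Y-Position factor indices; loading the images is outside this sketch.
    """
    rng = random.Random(seed)
    x, y = rng.random(), rng.random()  # random initial position
    frames = []
    for _ in range(n_steps):
        x = reflect(x + rng.gauss(0.0, step_std))
        y = reflect(y + rng.gauss(0.0, step_std))
        frames.append((round(x * (n_values - 1)), round(y * (n_values - 1))))
    return frames

frames = random_walk_indices(10_000)
```

With a reflecting boundary the stationary distribution of the walk is uniform over the square, which matches the statement that the ground-truth factors are uniformly distributed by construction.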

We push our algorithm further by learning a higher dimensional latent representation. In this case, three transformations (Scale, X-Position, Y-Position) of the dSprites dataset are considered. The results are presented in Fig. 5. We find that our algorithm can still successfully recover the underlying "ground-truth" factors (Scale, X-Position, Y-Position), which $\beta$-VAE fails to identify. While it clearly outperforms $\beta$-VAE, our algorithm does not entirely decouple the scale from the x and y positions. This is because there are only 6 distinct values along the Scale factor, so the boundary effect can no longer be ignored. This results in a curved manifold, as shown in the right column of Fig. 5(b).

Figure 5: Comparison of the results on dSprites dataset obtained from $\beta$-VAE and dynamics-AE with three transformations (Scale, X-Position, Y-Position) of the dataset. a. Reconstructions of dSprites dataset (top: $\beta$-VAE, bottom: dynamics-AE). First row: originals. Second row: reconstructions. Remaining three rows: reconstructions of latent traversals across each latent dimension. b. Recovering the latent representation from the sequential dSprites images. The left column represents the latent space learned from $\beta$-VAE. The right column represents the latent representation learned from dynamics-AE.

Figure 6: Comparison of the results on DNA dataset obtained from $\beta$-VAE and dynamics-AE. a. Reconstructions of the DNA molecule position (top: $\beta$-VAE, bottom: dynamics-AE). First row: originals. Second row: reconstructions. Remaining two rows: reconstructions of latent traversals across each latent dimension. b. Recovering the latent representation from the fluorescent DNA molecule. The top row represents the latent representation learned from $\beta$-VAE. The bottom row represents the latent space learned from dynamics-AE.

Fluorescent DNA

We finally apply our algorithm to a real-world video dataset consisting of the Brownian motion of DNA molecules in solution, as described in Ref. [12]. Following Ref. [4], we use the same method to obtain the "ground-truth" latent variables (the $x$ and $y$ coordinates of the center of the DNA molecule). The details of the preprocessing protocol are provided in SI. As shown in Fig. 6(a), both algorithms can reconstruct the positions of the molecules with a 2D latent space. However, even with a large $\beta$, $\beta$-VAE fails to disentangle the $x$ and $y$ positions in this case. In comparison, our algorithm can still successfully disentangle the $x$ and $y$ positions. Finally, Fig. 6(b) clearly illustrates that $\beta$-VAE can only learn a complex nonlinear warping of the ground-truth latent variables, while our algorithm can still uniquely recover the ground truth up to an isometry.

Discussion

In this work, we have proposed a purely dynamics-constrained representation learning framework, and demonstrated its power through a number of numerical examples. We believe that constraining the latent dynamics is a more natural way to regularize the latent space. Especially given that the accurate prior probability distribution is usually unavailable, constraining the latent dynamics may also arguably be the only generic way to regularize the latent representation.

There are a variety of additional avenues for such a dynamics-based approach. Through the numerical examples, our results further support the previous findings that constraining the latent dynamics is enough to uniquely identify the latent representation up to an isometry.[4, 13] More specifically in overdamped Langevin dynamics, we find our algorithm successfully decorrelates the latent representation by assuming a diagonal diffusion matrix. This in fact provides a promising way for dynamics-based disentanglement in representation learning. Besides, our algorithm can also preserve the intrinsic geometry by enforcing a constant homogeneous diffusion field in the latent space. Finally, as only dynamical information matters in the learning process, some more flexible spatial and temporal sampling strategies can be used to make full use of the dataset. In this work, we conduct a preliminary exploration of the sampling scheme by resampling the dataset based on the learned latent representation. Such a sampling strategy allows the algorithm to focus more on poorly sampled regions, greatly improving recovery of the underlying dynamic structure.

This framework provides a completely dynamical viewpoint for representation learning, which we hope can be a new starting point for combining physics and machine learning: dynamical models are an essential component of physics, while representation learning also plays a crucial role in machine learning; therefore, introducing physics-based dynamical models into representation learning provides a huge opportunity for both communities. On one hand, this framework provides a promising way to infer effective dynamical models for complex systems. On the other hand, incorporating familiar concepts from physics will also benefit the interpretability of machine learning. In this work, we focused on overdamped Langevin dynamics, but the framework should be easily generalized to other dynamical systems, such as general Langevin dynamics, deterministic Hamiltonian dynamics and the Schrödinger equation. Physical concepts other than the diffusion and free energy can also be introduced, such as the Hamiltonian, position, momentum and angular momentum. With these concepts, conservation laws and the associated symmetries can then be applied to the models. In other directions, a straightforward extension would be to move from the autoencoder to the information bottleneck, wherein the latent variables are required only to predict some target instead of reconstructing the original input data.[14, 15] The choice of the target can be domain-dependent, making the framework highly flexible.[16, 17] A neural network based likelihood maximization method is used here to infer the force field and diffusion field for general purposes, but in principle, other more specific numerical inference methods can also be used.[18, 19, 20, 21] Another extension would be to generalize our results to non-Euclidean latent spaces to capture topological properties of certain datasets. In principle, this can be done through some geometric constraints or direct projections to the target manifold.[22]

Methods

Calculating the sliced-Wasserstein distance

As can be seen from Eq. 1, the regularizer in VAE is the Kullback–Leibler divergence between the posterior distribution $q_\phi(z|x)$ and the prior distribution $p_\theta(z)$, averaged over the empirical data. However, as discussed in Ref. [10], this regularizer encourages the encoded distribution for each training sample to become increasingly similar to the prior distribution, which is fundamentally at odds with the goal of achieving good reconstruction. This can also be seen in the limiting case where $D_{\mathrm{KL}}\big(q_\phi(z|x)\,\|\,p_\theta(z)\big) = 0$. This equality occurs if and only if $q_\phi(z|x) = p_\theta(z)$, suggesting the encoder should forget all information about $x$. This can lead to difficulties in reconstruction of the input data. Therefore, Ref. [10] proposed Wasserstein Auto-Encoders (WAE) using generative adversarial networks (GAN) to directly match the distribution of the entire encoding space to the prior $p_\theta(z)$. This prevents the latent representations of different data points from collapsing together, thereby promoting better reconstructions while still maintaining simple latent representations. Subsequent research introduced a much simpler solution known as Sliced Wasserstein Auto-Encoders (SWAE).[11] The trick in SWAE is to project or slice the distribution along multiple randomly chosen directions, and minimize the Wasserstein distance along each of those one-dimensional spaces (which can even be done analytically). Following Ref. [11], let $\{z^n\}_{n=1}^{N}$ and $\{\tilde{z}^n\}_{n=1}^{N}$ be random samples from the encoded input data and the prior distribution respectively. Assuming $z \in \mathbb{R}^d$, let $\{\omega_l\}_{l=1}^{L}$ be randomly sampled directions from a uniform distribution on the $(d-1)$-sphere $\mathbb{S}^{d-1}$. Then the 1D projection of the latent representation $z$ onto $\omega_l$ is $\omega_l \cdot z$. Therefore, the regularizer in Eq. 1 can be measured by the sliced-Wasserstein distance:

$$\mathcal{L}_{\mathrm{reg}} = \frac{1}{LN}\sum_{l=1}^{L}\sum_{n=1}^{N} \left|\,\omega_l \cdot z^{i[n]} - \omega_l \cdot \tilde{z}^{j[n]}\right|^2 \tag{12}$$

where $i[n]$ and $j[n]$ are the indices of sorted $\omega_l \cdot z^n$ and $\omega_l \cdot \tilde{z}^n$ with respect to $n$, correspondingly. We can now generalize this regularizer loss to learning dynamics:

(13)

where .
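The sliced-Wasserstein estimate described in this subsection can be sketched in a few lines. This is an illustrative implementation, not the authors' code: it assumes a squared-difference cost between sorted projections and equally sized sample sets, and the function names and the default of 50 projections are hypothetical choices.

```python
import math
import random

def random_direction(dim, rng):
    """Draw a direction uniformly from the unit sphere by normalizing a Gaussian vector."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]

def sliced_wasserstein(samples_a, samples_b, n_projections=50, seed=0):
    """Monte Carlo estimate of the squared sliced-Wasserstein distance.

    Both sample sets must have the same size. For each random direction the
    samples are projected to 1D and sorted; sorted values are matched in order,
    which is the optimal 1D transport plan.
    """
    if len(samples_a) != len(samples_b):
        raise ValueError("both sample sets must have the same size")
    rng = random.Random(seed)
    dim = len(samples_a[0])
    total = 0.0
    for _ in range(n_projections):
        w = random_direction(dim, rng)
        proj_a = sorted(sum(wi * zi for wi, zi in zip(w, z)) for z in samples_a)
        proj_b = sorted(sum(wi * zi for wi, zi in zip(w, z)) for z in samples_b)
        total += sum((a - b) ** 2 for a, b in zip(proj_a, proj_b)) / len(proj_a)
    return total / n_projections

# Usage: identical sets give zero; a shifted copy gives a positive distance.
cloud = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
shifted = [[x + 3.0, y + 4.0] for x, y in cloud]
sw_same = sliced_wasserstein(cloud, cloud)
sw_shift = sliced_wasserstein(cloud, shifted)
sw_1d = sliced_wasserstein([[0.0], [1.0], [2.0]], [[2.0], [3.0], [4.0]])
```

Because only sorting and elementwise differences are involved, the same computation is differentiable with respect to the encoded samples when written in an automatic differentiation framework, which is what makes it usable as a training loss.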

Discretization and resampling schemes

For simplicity, regular space clustering with the Euclidean distance metric is used to discretize the latent space into uniformly distributed bins $\{\Omega_i\}_{i=1}^{M}$.[23] More advanced distance metrics such as the Mahalanobis distance[24] and others[25] are left for future exploration. The clustering is performed in such a way that cluster centers are at least $d_{\min}$ from each other according to the given metric. Samples are then assigned to cluster centers through Voronoi partitioning. In addition, to make the algorithm pay more attention to rarely sampled regions in the latent space, we resample the dataset according to a well-tempered distribution[26] by introducing a hyperparameter $\gamma$:

$$\tilde{n}_i = N\, \frac{n_i^{1/\gamma}}{\sum_{j=1}^{M} n_j^{1/\gamma}} \tag{14}$$

where $\tilde{n}_i$ is the number of samples drawn from the bin $\Omega_i$, $n_i$ is the actual number of samples within the bin $\Omega_i$ and $N$ is the total number of samples. For $\gamma = 1$, Eq. 14 reduces to $\tilde{n}_i = n_i$ and the original dataset is used for training. For $\gamma \to \infty$, we obtain $\tilde{n}_i = N/M$, which means that we draw samples uniformly from each bin. In practice, we find our results are robust to the choice of $\gamma$, and a moderate value is chosen in this paper.
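The discretization and resampling steps above can be sketched as follows. This is a minimal illustration under assumptions: the greedy pass is one common way to implement regular space clustering, the tempered weights are taken proportional to the bin counts raised to the power $1/\gamma$ (a form chosen to reproduce the two limits stated in the text), and all function names and the toy points are hypothetical.

```python
import math

def regular_space_centers(points, d_min):
    """Greedy regular space clustering.

    A point becomes a new center if it lies at least d_min (Euclidean
    distance) from every center chosen so far.
    """
    centers = []
    for p in points:
        if all(math.dist(p, c) >= d_min for c in centers):
            centers.append(p)
    return centers

def assign_bins(points, centers):
    """Voronoi partition: index of the nearest center for each point."""
    return [min(range(len(centers)), key=lambda i: math.dist(p, centers[i]))
            for p in points]

def tempered_counts(counts, gamma, n_total):
    """Target number of samples per bin, proportional to counts ** (1 / gamma)."""
    weights = [c ** (1.0 / gamma) for c in counts]
    norm = sum(weights)
    return [n_total * w / norm for w in weights]

# Toy latent points (assumed values): three loose groups of decreasing size.
points = [(0.0, 0.0), (0.1, 0.0), (0.2, 0.1), (1.0, 0.0), (1.1, 0.1), (2.0, 0.0)]
centers = regular_space_centers(points, d_min=0.5)
labels = assign_bins(points, centers)
counts = [labels.count(i) for i in range(len(centers))]

original = tempered_counts(counts, gamma=1.0, n_total=len(points))   # keeps the data as sampled
flattened = tempered_counts(counts, gamma=1e9, n_total=len(points))  # close to uniform over bins
```

Intermediate values of `gamma` interpolate between these two limits, so sparsely populated bins are oversampled without discarding the original sampling density entirely.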

Neural network architecture and training

As shown in Fig. 1(a), the full algorithm consists of four neural networks: the encoder, the decoder, and the two networks that learn the diffusion field and the force field. For non-image datasets, we use fully connected neural networks with 2 hidden layers and ReLU activations to parameterize the encoder and decoder. For image datasets, we use a convolutional neural network for the encoder and a deconvolutional neural network for the decoder. The diffusion and force fields in all models are parameterized using the same fully connected architecture with 3 hidden layers and tanh activation functions. The networks were trained using two Adam optimizers with the same learning rate of 0.001, optimizing the two loss functions shown in Eqs. 9 and 11 separately. The implementation details are in SI.

Algorithm 1 Dynamics constrained representation learning
Input: a long unbiased trajectory, the latent space dimensionality, the number of projections used to approximate the sliced-Wasserstein distance, the minimal distance between cluster centers, the batch size, and the remaining hyperparameters
for each epoch do
     if the bins are due to be updated at this epoch then
         Encode the input data
         Discretize the latent space into bins
         Resample the dataset by drawing samples from each bin
     end if
     for each minibatch do
         Sample a minibatch from a randomly picked bin
         Encode the input samples
         Generate prior samples
         Calculate the representation loss (Eq. 9)
         Update the encoder and decoder parameters through backpropagation
         Calculate the prior loss (Eq. 11)
         Update the parameters of the force field and diffusion networks through backpropagation
     end for
end for

Code availability

The Python code, written with PyTorch, will be made available for public use at https://github.com/tiwarylab.

Data availability

The data that support the findings of this study are available from the corresponding author upon reasonable request.

References

Acknowledgements

This research was entirely supported by the U.S. Department of Energy, Office of Science, Basic Energy Sciences, CPIMS Program, under Award DE-SC0021009. This work used the Extreme Science and Engineering Discovery Environment (XSEDE) Bridges through allocation Grant No. CHE180027P, which is supported by the National Science Foundation Grant No. ACI-1548562. We also thank MARCC’s Bluecrab HPC clusters for computing resources.

Author contributions statement

D.W. and P.T. designed research; D.W., Y.W. and L.E. performed research; D.W. analyzed data and wrote the code. The manuscript was written through contributions of all authors. All authors have given approval to the final version of the manuscript.

Additional information

The authors declare no competing interests.