Publication date
05/15/2024
Authors
Alexander Korotin, Marat Khamadeev

Energy-based models made Schrödinger Bridge lighter


Since the early 2000s, energy-based models (EBMs) have been gaining popularity in machine learning. They apply the idea of the canonical ensemble from statistical physics to a model's probability distributions. A canonical ensemble describes a system in thermal equilibrium with its environment. As a result, the probability of finding the system in a given configuration depends on that configuration's energy according to the Boltzmann distribution.
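In its simplest form, the Boltzmann law assigns a configuration with energy E_i the probability p_i = exp(-E_i / (k_B T)) / Z, where the partition function Z sums the same exponentials over all configurations. A minimal sketch for a few discrete states, assuming k_B = 1. The function name, energies, and temperatures are made up for illustration.

```python
import numpy as np

def boltzmann_probabilities(energies, temperature=1.0):
    """Return p_i = exp(-E_i / T) / Z for discrete states (Boltzmann constant set to 1)."""
    logits = -np.asarray(energies, dtype=float) / temperature
    logits -= logits.max()  # shift for numerical stability; cancels after normalization
    weights = np.exp(logits)
    return weights / weights.sum()  # dividing by the sum plays the role of Z

# Three hypothetical configurations: lower energy means higher probability.
probs = boltzmann_probabilities([0.0, 1.0, 2.0], temperature=1.0)
```

Raising the temperature flattens the distribution toward uniform. Lowering it concentrates the probability mass on the lowest-energy state.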

In energy-based machine learning models, a similar relationship links the probability distribution to an energy potential, i.e., an unnormalized likelihood function. Expressing a distribution through such a function has several advantages, including simplicity, stability, and adaptability. Energy-based neural networks have also proven useful for image generation.
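To make the "unnormalized likelihood" idea concrete: an EBM defines p(x) = exp(-E(x)) / Z, and the normalizer Z is usually intractable. Samples are therefore typically drawn with MCMC methods that need only the gradient of E. The sketch below uses unadjusted Langevin dynamics, which is one common choice. The article does not name a sampler, and the double-well energy and all function names are hypothetical.

```python
import numpy as np

def energy_grad(x):
    """Gradient of a hypothetical double-well energy E(x) = (x**2 - 1)**2."""
    return 4.0 * x * (x**2 - 1.0)

def langevin_sample(grad_E, n_samples=5000, n_steps=1000, step=1e-2, seed=0):
    """Unadjusted Langevin dynamics: x <- x - step * grad E(x) + sqrt(2 * step) * noise.

    Only the gradient of the energy is used, so the normalizer Z is never computed.
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n_samples)  # independent chains started from N(0, 1)
    for _ in range(n_steps):
        x = x - step * grad_E(x) + np.sqrt(2.0 * step) * rng.normal(size=n_samples)
    return x

samples = langevin_sample(energy_grad)  # concentrates around the two wells at x = -1 and x = +1
```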

Concurrently, generative approaches based on optimal transport (OT) are being actively developed. Optimal transport is a class of problems about moving one probability distribution into another as efficiently as possible, i.e., at minimal total cost. A team of researchers from AIRI and Skoltech led by Evgeny Burnaev has made significant progress in this direction. For instance, they proposed a rigorous formulation of entropy-regularized optimal transport solved with neural networks (Entropic Neural Optimal Transport, ENOT) and provided a mathematical justification for searching for the theoretically best domain translation using unpaired training sets.
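ENOT trains neural networks, but the underlying entropic OT problem is easiest to see in the discrete case. The task is to find a coupling P between two weighted point clouds that minimizes the transport cost minus ε times the entropy of P, subject to matching both marginals. The sketch below uses the classic Sinkhorn iterations, a standard textbook solver shown only for intuition. It is not the ENOT method, and the point clouds, ε, and function name are hypothetical.

```python
import numpy as np

def sinkhorn(a, b, cost, eps=1.0, n_iters=500):
    """Discrete entropic OT: min_P <P, C> - eps * H(P) s.t. P 1 = a, P^T 1 = b.

    The optimum has the form P = diag(u) K diag(v) with Gibbs kernel K = exp(-C / eps).
    """
    K = np.exp(-cost / eps)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)      # rescale rows to match marginal a
        v = b / (K.T @ u)    # rescale columns to match marginal b
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
x = rng.normal(scale=0.5, size=(5, 2))            # source points
y = rng.normal(loc=2.0, scale=0.5, size=(6, 2))   # target points
C = 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # quadratic cost
P = sinkhorn(np.full(5, 1 / 5), np.full(6, 1 / 6), C, eps=1.0)
```

Larger ε gives a blurrier plan that converges faster. As ε shrinks, the plan approaches unregularized OT, but more iterations are needed.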

This time, the scientists introduced a new methodology that uses advances in EBMs to enhance ENOT. First, they derived an optimization procedure and a corresponding algorithm that implicitly recover optimal transport plans through an energy representation. The authors also conducted a thorough theoretical analysis of the method.
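One way to picture an energy representation of a transport plan is to assume a quadratic cost and a conditional plan given as an unnormalized density in y: π(y | x) ∝ exp((f(y) - ‖x - y‖²/2) / ε), where f is a learned potential. Samples can then be drawn with Langevin dynamics, which never needs the normalizing constant. This form and the sampler are assumptions made for this sketch. Here f is a hypothetical hand-picked quadratic rather than a trained network, so the answer can be checked in closed form.

```python
import numpy as np

EPS = 0.5                   # entropic regularization strength (hypothetical value)
M = np.array([2.0, -1.0])   # center of the hypothetical potential

def grad_f(y):
    """Gradient of a hypothetical potential f(y) = -0.5 * ||y - M||^2 (a trained net in practice)."""
    return -(y - M)

def grad_log_plan(y, x):
    """Gradient in y of log pi(y|x) = (f(y) - 0.5 * ||x - y||^2) / EPS + const."""
    return (grad_f(y) - (y - x)) / EPS

def sample_conditional_plan(x, n_samples=4000, n_steps=2000, step=1e-2, seed=0):
    """Unadjusted Langevin dynamics targeting pi(y|x); the normalizer is never needed."""
    rng = np.random.default_rng(seed)
    y = rng.normal(size=(n_samples, x.shape[0]))
    for _ in range(n_steps):
        noise = rng.normal(size=y.shape)
        y = y + step * grad_log_plan(y, x) + np.sqrt(2.0 * step) * noise
    return y

x0 = np.array([0.0, 0.0])
ys = sample_conditional_plan(x0)
# For this quadratic f, pi(y|x) is Gaussian with mean (x + M) / 2 and variance EPS / 2 per axis.
```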

To demonstrate the advantages of the new approach, the team ran a series of experiments with EBM-augmented models: a toy 2D scenario, Gaussian-to-Gaussian transport, and high-dimensional AFHQ Cat/Wild→Dog image translation. In these experiments, the proposed method achieved FID values comparable to those of the baselines.
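FID compares two image sets by fitting a Gaussian to each set's Inception-network features. It then computes the Fréchet distance between the two Gaussians: ‖μ₁ - μ₂‖² + Tr(Σ₁ + Σ₂ - 2(Σ₁Σ₂)^{1/2}). The sketch below implements only this final step on arbitrary feature arrays. It omits Inception feature extraction, which a real FID requires, and the random "features" and the function name are placeholders.

```python
import numpy as np

def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fitted to two (n_samples, dim) feature arrays."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # cov_a @ cov_b is similar to a PSD matrix, so its eigenvalues are real and non-negative,
    # and Tr((cov_a cov_b)^{1/2}) is the sum of their square roots.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * tr_covmean)

rng = np.random.default_rng(0)
real = rng.normal(size=(2000, 8))            # placeholder features of "real" images
fake = rng.normal(loc=0.5, size=(2000, 8))   # placeholder features of "generated" images
score = frechet_distance(real, fake)
```

Lower is better. Identical feature sets give a distance of zero up to rounding error.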

AFHQ 512 × 512 Cat→Dog unpaired translation by the new Energy-guided EOT solver applied in the latent space of StyleGAN2-ADA

It turned out that the energy-based approach allows for an even more efficient solution to the problem of building a Schrödinger Bridge (SB), i.e., finding the most probable stochastic transition from one distribution to another. This is possible because, under certain assumptions, the SB problem is equivalent to entropic optimal transport (EOT). Earlier, the researchers had already released a benchmark for testing neural-network solvers that build Schrödinger Bridges.
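In the standard setting where the SB reference process is a Brownian motion with volatility ε (the case matching EOT with quadratic cost), the path between a pair of endpoints (x₀, x₁) drawn from the EOT plan is a Brownian bridge. Its intermediate state is X_t = (1 - t)x₀ + t x₁ + √(ε t(1 - t)) Z, where Z is standard normal. A minimal sketch of sampling these intermediate states. The endpoints are hypothetical stand-ins for pairs drawn from an EOT plan.

```python
import numpy as np

def brownian_bridge_sample(x0, x1, t, eps, rng):
    """Sample X_t of a Brownian bridge (volatility eps) pinned at x0 (t=0) and x1 (t=1)."""
    mean = (1.0 - t) * x0 + t * x1
    std = np.sqrt(eps * t * (1.0 - t))
    return mean + std * rng.normal(size=np.shape(x0))

rng = np.random.default_rng(0)
x0 = np.zeros((10000, 2))   # hypothetical start points
x1 = np.ones((10000, 2))    # hypothetical end points, paired with x0 by an EOT plan
xt = brownian_bridge_sample(x0, x1, t=0.5, eps=0.1, rng=rng)
```

The noise vanishes at both ends and peaks at t = 0.5, so every sampled path starts at x₀ and ends exactly at x₁.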

Almost all existing EOT/SB solvers have complex neural network parameterizations and many hyperparameters and, as a consequence, require time-consuming training and inference. While this may be acceptable for large-scale generative modeling, these techniques are too heavyweight for moderate-dimensional data distributions, e.g., those arising in promising biological applications of OT.

The authors of the new study overcame this difficulty by creating a lightweight Schrödinger Bridge solver (LightSB) based on two ideas: parameterizing the Schrödinger potentials with Gaussian mixtures and using these potentials as an energy function. Theoretical analysis showed that the solver is a universal approximator of SBs. Its convergence is also guaranteed: the estimation error vanishes at the standard parametric rate as the sample size grows.
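A Gaussian-mixture potential can make the conditional plan available in closed form. For this sketch, assume a quadratic cost, regularization ε, a potential v(y) = Σ_k α_k N(y | r_k, ε S_k), and a plan π(y | x) ∝ exp(⟨x, y⟩/ε) v(y). Completing the square gives another Gaussian mixture: π(y | x) = Σ_k w_k(x) N(y | r_k + S_k x, ε S_k), with w_k(x) ∝ α_k exp((xᵀS_k x + 2 r_kᵀx)/(2ε)). Sampling therefore needs no MCMC at all. Diagonal S_k are assumed for brevity, and the parameters and function name are hypothetical.

```python
import numpy as np

def sample_mixture_plan(x, log_alpha, r, s_diag, eps, n_samples, rng):
    """Sample y ~ pi(y|x) = sum_k w_k(x) N(y | r_k + S_k x, eps * S_k) with diagonal S_k.

    log_alpha: (K,) log mixture weights; r: (K, D) means; s_diag: (K, D) diagonals of S_k.
    """
    # log w_k(x) = log alpha_k + (x^T S_k x + 2 r_k^T x) / (2 eps), then normalize.
    logits = log_alpha + ((s_diag * x**2).sum(axis=1) + 2.0 * r @ x) / (2.0 * eps)
    w = np.exp(logits - logits.max())
    w /= w.sum()
    comps = rng.choice(len(w), size=n_samples, p=w)   # pick a mixture component per sample
    means = r[comps] + s_diag[comps] * x               # r_k + S_k x
    stds = np.sqrt(eps * s_diag[comps])                # square root of diag(eps * S_k)
    return means + stds * rng.normal(size=means.shape), w

rng = np.random.default_rng(0)
K, D, eps = 3, 2, 0.1
log_alpha = np.log(np.full(K, 1.0 / K))
r = rng.normal(size=(K, D))
s_diag = np.full((K, D), 0.5)
ys, weights = sample_mixture_plan(np.array([0.3, -0.2]), log_alpha, r, s_diag, eps, 5000, rng)
```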

As a practical test, the researchers applied the new method to synthetic and real data, including human faces and biological data. Building an SB with the new solver takes only a few minutes, because it requires no max-min optimization, no simulation of full process trajectories, and none of the other time-consuming procedures.
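One way to see why no max-min optimization or trajectory simulation is needed is to look at a closed-form training loss. For this sketch, assume a quadratic cost, regularization ε, and a Gaussian-mixture potential v(y) = Σ_k α_k N(y | r_k, ε S_k) with diagonal S_k. By the Gaussian moment-generating function, the normalizer c(x) = ∫ exp(⟨x, y⟩/ε) v(y) dy equals Σ_k α_k exp((xᵀS_k x + 2 r_kᵀx)/(2ε)). The sketch further assumes an objective of the form mean log c(x) over source samples minus mean log v(y) over target samples. Under these assumptions, training is a plain minimization with no inner adversarial loop and no SDE simulation. The functions and parameters are hypothetical.

```python
import numpy as np

def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along an axis."""
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def log_c(x, log_alpha, r, s_diag, eps):
    """log c(x) for x of shape (N, D): logsumexp_k of log alpha_k + (x^T S_k x + 2 r_k^T x) / (2 eps)."""
    quad = (x[:, None, :] ** 2 * s_diag[None]).sum(axis=-1)       # (N, K)
    lin = x @ r.T                                                 # (N, K)
    return logsumexp(log_alpha[None] + (quad + 2.0 * lin) / (2.0 * eps), axis=1)

def log_v(y, log_alpha, r, s_diag, eps):
    """log v(y) for the mixture sum_k alpha_k N(y | r_k, eps * diag(s_k))."""
    var = eps * s_diag[None]                                      # (1, K, D)
    sq = (y[:, None, :] - r[None]) ** 2 / var                     # (N, K, D)
    log_norm = -0.5 * (np.log(2.0 * np.pi * var) + sq).sum(axis=-1)  # (N, K)
    return logsumexp(log_alpha[None] + log_norm, axis=1)

def mixture_sb_loss(x_src, y_tgt, log_alpha, r, s_diag, eps):
    """Simulation-free objective: mean log c(x) over source minus mean log v(y) over target."""
    return (log_c(x_src, log_alpha, r, s_diag, eps).mean()
            - log_v(y_tgt, log_alpha, r, s_diag, eps).mean())

rng = np.random.default_rng(0)
K, D, eps = 4, 2, 0.1
params = (np.log(np.full(K, 1 / K)), rng.normal(size=(K, D)), np.ones((K, D)))
loss = mixture_sb_loss(rng.normal(size=(256, D)), rng.normal(loc=2.0, size=(256, D)), *params, eps)
```

In practice, such a loss would be minimized over the mixture parameters with an ordinary gradient-based optimizer, with each step costing only a few matrix operations.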

Unpaired male → female translation by the LightSB solver applied in the latent space of ALAE for 1024 × 1024 FFHQ

Both results were accepted at the ICLR 2024 conference, and details can be found in the conference proceedings (EBM-ENOT and LightSB). The research code is available on GitHub (EBM-ENOT and LightSB).

