Publication date
01/23/2025
Authors
Marat Khamadeev, Maksim Bobrin

Expectile regularization accelerated optimal transport tenfold


Approaches based on optimal transport (OT) have provided a new perspective on a variety of machine learning tasks, ranging from generative modeling to domain adaptation. These methods operate in the language of probability distributions and transitions between them.

The task of OT typically involves finding the optimal mapping (optimal plan) between two distributions under a given cost function. In practice, the optimum is often sought by solving the dual formulation of the Kantorovich problem, with the optimal solution approximated by parametric functions. However, this approach can be unstable when searching for the optimal potentials, one of which is the conjugate (c-transform) of the other with respect to the cost function.

A rough estimate of the conjugate potential can cause the sum of the potentials to diverge, while computing it exactly requires solving an additional inner optimization problem. Existing methods for addressing this issue require extensive tuning of neural network hyperparameters or iterative searches for a more accurate solution to the inner problem at each optimization step, which makes them computationally expensive.
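The inner problem can be made concrete with a toy sketch in plain Python, which is an illustration and not ENOT's implementation. For a fixed potential f on two small 1D point clouds, the exact conjugate f^c(y) = min_x [c(x, y) - f(x)] needs a separate minimization for every y, and an overestimated conjugate both breaks the constraint f(x) + g(y) <= c(x, y) and inflates the dual objective. The quadratic cost, the point clouds, the offset of 0.3, and all function names are assumptions made for this example.

```python
# Toy illustration (assumed setup, not ENOT's code): Kantorovich dual potentials
# on small 1D point clouds with uniform weights and cost c(x, y) = (x - y)^2 / 2.


def cost(x, y):
    return 0.5 * (x - y) ** 2


def c_transform(f_vals, xs, y):
    """Exact conjugate f^c(y) = min_x [c(x, y) - f(x)]: one inner minimization per y."""
    return min(cost(x, y) - fx for x, fx in zip(xs, f_vals))


def dual_objective(f_vals, g_vals):
    """Dual objective: mean of f over the source points plus mean of g over the targets."""
    return sum(f_vals) / len(f_vals) + sum(g_vals) / len(g_vals)


def max_violation(f_vals, g_vals, xs, ys):
    """Largest excess of f(x) + g(y) over c(x, y); a value <= 0 means the pair is feasible."""
    return max(
        fx + gy - cost(x, y)
        for x, fx in zip(xs, f_vals)
        for y, gy in zip(ys, g_vals)
    )


xs = [-1.0, 0.0, 1.0, 2.0]
ys = [-0.5, 0.5, 1.5]
f_vals = [0.1 * x for x in xs]                      # an arbitrary potential f
g_exact = [c_transform(f_vals, xs, y) for y in ys]  # exact conjugate, costly in general
g_rough = [g + 0.3 for g in g_exact]                # a conjugate overestimated by 0.3

print("violation (exact):", max_violation(f_vals, g_exact, xs, ys))
print("violation (rough):", max_violation(f_vals, g_rough, xs, ys))
print("objective gain from overestimate:",
      dual_objective(f_vals, g_rough) - dual_objective(f_vals, g_exact))
```

Because the inflated objective keeps growing with the offset, an optimizer that maximizes it without enforcing the constraint can drift toward divergence. With neural potentials on continuous distributions the exact minimum over x is no longer cheap to compute, which is why some methods approximate it with an inner optimization loop at every step.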

Researchers from AIRI and Skoltech proposed alleviating these difficulties by restricting the class of solutions for the inner optimization of the conjugate potential through expectile regularization, an asymmetric generalization of the mean squared error (MSE). The idea is to move the inner optimization over the conjugate potential into the outer, original Kantorovich problem. The resulting optimization task ensures stable and balanced convergence both in finding the optimal potentials and in approximating the transport plan. The new method is called Expectile-Regularized Neural Optimal Transport (ENOT).
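The asymmetric loss behind expectile regression fits in a few lines of plain Python. This is a toy sketch, not the ENOT loss itself: residuals above zero are weighted by tau and residuals below zero by 1 - tau, so tau = 0.5 recovers half the squared error (whose minimizer is the sample mean), while tau close to 1 pushes the minimizer, the tau-expectile, toward the sample maximum. That last property is one intuition for why expectiles can serve as a smooth stand-in for the extremum that appears in a conjugate potential. The function names, the bisection solver, and the sample values are illustrative assumptions.

```python
# Toy sketch (assumed, not the ENOT loss): the asymmetric squared loss behind
# expectile regression, and the sample tau-expectile that minimizes it.


def expectile_loss(u, tau):
    """|tau - 1[u < 0]| * u**2; with tau = 0.5 this is half the squared error."""
    weight = tau if u >= 0 else 1.0 - tau
    return weight * u * u


def sample_expectile(xs, tau, iters=200):
    """Return the m minimizing sum(expectile_loss(x - m, tau) for x in xs), via bisection."""

    def balance(m):
        # Equals minus half the derivative of the total loss w.r.t. m; decreasing in m.
        return sum((tau if x >= m else 1.0 - tau) * (x - m) for x in xs)

    lo, hi = min(xs), max(xs)  # the minimizer always lies within the sample range
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if balance(mid) > 0:  # loss still decreasing here: the minimizer is to the right
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


sample = [0.0, 1.0, 2.0, 3.0, 10.0]
for tau in (0.5, 0.9, 0.99, 0.999):
    # tau = 0.5 gives the mean (3.2); larger tau moves toward the maximum (10.0).
    print(tau, round(sample_expectile(sample, tau), 4))
```

Unlike a hard maximum, the expectile stays a smooth, convex regression target for every tau below 1, which is what makes it attractive inside a gradient-based training loop.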

The authors evaluated ENOT against existing methods on a popular benchmark for constructing Wasserstein-2 OT plans, which was also developed with contributions from AIRI researchers. In this benchmark, models must recover the mapping between several paired synthetic datasets of different dimensions, as well as between sets of images, for which the optimal plan is known in advance. The researchers also experimented with different cost functionals and assessed their method on more complex generative image-to-image (img2img) tasks.
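Benchmarks of this kind rely on cases where the optimal Wasserstein-2 plan is known. The simplest such case, used here purely for illustration and not taken from the benchmark, is one dimension: for two equal-size samples with uniform weights and squared-distance cost, the monotone (sorted) matching is an optimal plan. The sketch below checks this against a brute-force search over all one-to-one matchings; the sample values and function names are assumptions.

```python
from itertools import permutations

# Toy check (assumed example, not the benchmark): in 1D with equal-size samples and
# uniform weights, the sorted (monotone) matching is an optimal W2 plan.


def w2_squared_sorted(xs, ys):
    """Squared W2 distance between two equal-size 1D samples via the sorted matching."""
    if len(xs) != len(ys):
        raise ValueError("samples must have the same size")
    pairs = zip(sorted(xs), sorted(ys))
    return sum((x - y) ** 2 for x, y in pairs) / len(xs)


def w2_squared_brute_force(xs, ys):
    """Same quantity, minimizing over every one-to-one matching (tiny samples only)."""
    n = len(xs)
    return min(
        sum((xs[i] - ys[p[i]]) ** 2 for i in range(n)) / n
        for p in permutations(range(n))
    )


source = [0.3, -1.2, 2.0, 0.8, -0.1]
target = [1.5, 0.0, -0.7, 3.1, 0.4]
print(w2_squared_sorted(source, target))
print(w2_squared_brute_force(source, target))
```

In higher dimensions there is generally no such shortcut, which is why benchmarks whose plans are constructed to be known in advance are valuable for testing neural OT solvers.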


An example of generating 3D objects from noise using several methods, including ENOT

The evaluation showed that the new method accurately reconstructs the transport plan while requiring significantly fewer computational resources, less time for hyperparameter tuning, and fewer iterations than the baselines. In several cases, expectile regularization yielded a tenfold speedup over state-of-the-art (SOTA) approaches.

ENOT is implemented in JAX and has already been integrated into ott-jax, one of the most popular libraries for computing OT, which makes the method easy to use. The authors presented their work at NeurIPS 2024, where the paper was selected as a Spotlight, a distinction reserved for papers particularly highlighted by reviewers. A brief overview of the method can be found at the provided link.
