The Wasserstein barycenter is a principled approach to representing the weighted mean of a given set of probability distributions, utilizing the geometry induced by optimal transport. In this work, we present a novel scalable algorithm to approximate Wasserstein barycenters, aimed at high-dimensional applications in machine learning. Our proposed algorithm is based on the Kantorovich dual formulation of the Wasserstein-2 distance as well as a recent neural network architecture, the input convex neural network, which is known to parametrize convex functions. The distinguishing features of our method are: i) it only requires samples from the marginal distributions; ii) unlike existing approaches, it represents the barycenter with a generative model and can thus generate an unlimited number of samples from the barycenter without querying the marginal distributions; iii) it works similarly to a generative adversarial model in the single-marginal case. We demonstrate the efficacy of our algorithm by comparing it with state-of-the-art methods in multiple experiments.
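As a rough illustration of why an input convex neural network parametrizes convex functions, the following is a minimal NumPy sketch of a two-layer forward pass. It is not the architecture or training code of the work described above: the layer sizes, the parameter names (`A0`, `W1`, `w_out`, and so on), and the use of an absolute value to keep the hidden-to-hidden weights non-negative are all assumptions made for this example.

```python
import numpy as np

def init_icnn(dim, hidden, seed=0):
    """Random parameters for a tiny two-layer ICNN (hypothetical layout)."""
    rng = np.random.default_rng(seed)
    return {
        "A0": rng.normal(size=(hidden, dim)),      # input -> layer 1 (any sign)
        "b0": rng.normal(size=hidden),
        "W1": rng.normal(size=(hidden, hidden)),   # layer 1 -> layer 2 (made >= 0 below)
        "A1": rng.normal(size=(hidden, dim)),      # skip connection from input (any sign)
        "b1": rng.normal(size=hidden),
        "w_out": rng.normal(size=hidden),          # layer 2 -> output (made >= 0 below)
        "a_out": rng.normal(size=dim),             # linear term in the input (any sign)
    }

def icnn_forward(x, p):
    """Scalar output that is convex in x.

    Convexity holds because every weight acting on a hidden activation is
    non-negative and ReLU is convex and non-decreasing.
    """
    z = np.maximum(p["A0"] @ x + p["b0"], 0.0)
    z = np.maximum(np.abs(p["W1"]) @ z + p["A1"] @ x + p["b1"], 0.0)
    return float(np.abs(p["w_out"]) @ z + p["a_out"] @ x)
```

A quick sanity check is midpoint convexity: for any two inputs `x` and `y`, `icnn_forward((x + y) / 2, p)` should not exceed the average of `icnn_forward(x, p)` and `icnn_forward(y, p)`.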
This content will become publicly available on July 13, 2026
Optimal Transport Barycenter via Nonconvex-Concave Minimax Optimization
The optimal transport barycenter (a.k.a. Wasserstein barycenter) is a fundamental notion of averaging that extends from the Euclidean space to the Wasserstein space of probability distributions. Computation of the unregularized barycenter for discretized probability distributions on point clouds is a challenging task when the domain dimension $d > 1$. Most practical algorithms for approximating the barycenter problem are based on entropic regularization. In this paper, we introduce a primal-dual algorithm with nearly linear time complexity $O(m \log m)$ and linear space complexity $O(m)$, the Wasserstein-Descent $\dot{\mathbb{H}}^1$-Ascent (WDHA) algorithm, for computing the exact barycenter when the input probability density functions are discretized on an $m$-point grid. The success of the WDHA algorithm hinges on alternating between two different yet closely related optimization geometries, Wasserstein and Sobolev, for the primal barycenter and dual Kantorovich potential subproblems, respectively. Under reasonable assumptions, we establish the convergence rate and iteration complexity of WDHA to its stationary point when the step size is appropriately chosen. Superior computational efficacy, scalability, and accuracy over the existing Sinkhorn-type algorithms are demonstrated on high-resolution (e.g., 1024×1024 images) 2D synthetic and real data.
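For contrast with the $d > 1$ setting that the abstract calls challenging, the one-dimensional case has a closed form: the quantile function of the Wasserstein-2 barycenter is the weighted average of the input quantile functions. The sketch below is not the WDHA algorithm. It only illustrates the notion of averaging in Wasserstein space, and it assumes every input is an empirical distribution with the same number of equally weighted samples; the function name `barycenter_1d` is hypothetical.

```python
import numpy as np

def barycenter_1d(samples, weights):
    """Wasserstein-2 barycenter of 1-D empirical distributions.

    samples: list of 1-D arrays, all of the same length n (assumption).
    weights: non-negative barycentric weights, one per distribution.
    Returns the n support points of the barycenter, in increasing order.
    """
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()          # normalize to the simplex
    # Sorting each sample gives its empirical quantile function on a common grid.
    quantiles = np.stack([np.sort(np.asarray(s, dtype=float)) for s in samples])
    # Averaging quantile functions is exact for W2 barycenters in one dimension.
    return weights @ quantiles
```

For example, `barycenter_1d([[0, 2, 4], [14, 10, 12]], [0.5, 0.5])` averages the sorted samples point by point, which shifts mass halfway between the two inputs instead of mixing them as a Euclidean average of densities would.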
- Award ID(s):
- 2347760
- PAR ID:
- 10621410
- Publisher / Repository:
- International Conference on Machine Learning
- Date Published:
- Format(s):
- Medium: X
- Location:
- Vancouver
- Sponsoring Org:
- National Science Foundation
More Like this
-
Wasserstein distance plays increasingly important roles in machine learning, stochastic programming and image processing. Major efforts have been under way to address its high computational complexity, some leading to approximate or regularized variations such as Sinkhorn distance. However, as we will demonstrate, regularized variations with a large regularization parameter degrade performance in several important machine learning applications, while a small regularization parameter fails due to numerical stability issues with existing algorithms. We address this challenge by developing an Inexact Proximal point method for exact Optimal Transport problem (IPOT), with the proximal operator approximately evaluated at each iteration using projections to the probability simplex. The algorithm (a) converges to the exact Wasserstein distance with theoretical guarantees and robust regularization parameter selection, (b) alleviates the numerical stability issue, (c) has similar computational complexity to Sinkhorn, and (d) avoids the shrinking problem when applied to generative models. Furthermore, a new algorithm is proposed based on IPOT to obtain sharper Wasserstein barycenters.
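The proximal point idea can be sketched as follows: each outer step re-solves a regularized problem whose reference plan is the previous iterate, so the regularization does not bias the limit. This is a minimal sketch under assumptions, not the authors' implementation: the inner solve is approximated here by a single Sinkhorn-style scaling update, and the parameter names and defaults (`beta`, `n_outer`, `n_inner`) are chosen for illustration only.

```python
import numpy as np

def ipot_sketch(mu, nu, C, beta=1.0, n_outer=200, n_inner=1):
    """Inexact proximal point iteration for discrete optimal transport.

    mu, nu: marginal weight vectors (positive, each summing to 1).
    C: cost matrix of shape (len(mu), len(nu)).
    Returns the transport plan T and the transport cost <T, C>.
    """
    n, m = C.shape
    G = np.exp(-C / beta)            # Gibbs kernel for proximal parameter beta
    T = np.ones((n, m))              # initial plan (unnormalized)
    b = np.ones(m) / m               # scaling vector carried across outer steps
    for _ in range(n_outer):
        Q = G * T                    # kernel reweighted by the previous plan
        for _ in range(n_inner):     # inexact inner solve: a few scaling updates
            a = mu / (Q @ b)
            b = nu / (Q.T @ a)
        T = a[:, None] * Q * b[None, :]
    return T, float(np.sum(T * C))
```

Because `b` is updated last, the column sums of the returned plan match `nu` up to rounding at every iteration, while the row sums match `mu` only approximately until the iteration has converged.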
-
We consider synthesis and analysis of probability measures using the entropy-regularized Wasserstein-2 cost and its unbiased version, the Sinkhorn divergence. The synthesis problem consists of computing the barycenter, with respect to these costs, of m reference measures given a set of coefficients belonging to the m-dimensional simplex. The analysis problem consists of finding the coefficients for the closest barycenter in the Wasserstein-2 distance to a given measure μ. Under the weakest assumptions on the measures thus far in the literature, we compute the derivative of the entropy-regularized Wasserstein-2 cost. We leverage this to establish a characterization of regularized barycenters as solutions to a fixed-point equation for the average of the entropic maps from the barycenter to the reference measures. This characterization yields a finite-dimensional, convex, quadratic program for solving the analysis problem when μ is a barycenter. It is shown that these coordinates, as well as the value of the barycenter functional, can be estimated from samples with dimension-independent rates of convergence, a hallmark of entropy-regularized optimal transport, and we verify these rates experimentally. We also establish that barycentric coordinates are stable with respect to perturbations in the Wasserstein-2 metric, suggesting a robustness of these coefficients to corruptions. We employ the barycentric coefficients as features for classification of corrupted point cloud data, and show that compared to neural network baselines, our approach is more efficient in small training data regimes.
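To make the relation between the entropy-regularized cost and its debiased version concrete, here is a minimal sketch for weighted point sets on the real line. It assumes the cost $|x - y|^2$, regularization by the relative entropy with respect to the product of the marginals, and plain Sinkhorn iterations without log-domain stabilization. These conventions, and the function names, are assumptions for this example and may differ from those of the work described above.

```python
import numpy as np

def entropic_ot(a, x, b, y, eps, n_iter=500):
    """Entropy-regularized OT value between sum_i a_i delta_{x_i} and sum_j b_j delta_{y_j}.

    Evaluated through the dual potentials after n_iter Sinkhorn updates.
    Weights a and b must be strictly positive and sum to 1.
    """
    C = (x[:, None] - y[None, :]) ** 2       # squared-distance cost (assumption)
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    f = eps * np.log(u / a)                  # dual potential on the first measure
    g = eps * np.log(v / b)                  # dual potential on the second measure
    return float(a @ f + b @ g)

def sinkhorn_divergence(a, x, b, y, eps, n_iter=500):
    """Debiased cost: OT_eps(a, b) - (OT_eps(a, a) + OT_eps(b, b)) / 2."""
    cross = entropic_ot(a, x, b, y, eps, n_iter)
    self_a = entropic_ot(a, x, a, x, eps, n_iter)
    self_b = entropic_ot(b, y, b, y, eps, n_iter)
    return cross - 0.5 * self_a - 0.5 * self_b
```

Subtracting the two self-terms is what removes the entropic bias: the divergence of a measure with itself is zero by construction, which is not true of the regularized cost alone.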
-
We propose a new primal-dual homotopy smoothing algorithm for a linearly constrained convex program, where neither the primal nor the dual function has to be smooth or strongly convex. The best known iteration complexity for solving such a non-smooth problem is $O(\varepsilon^{-1})$. In this paper, we show that by leveraging a local error bound condition on the dual function, the proposed algorithm can achieve a better primal convergence time of $O\left(\varepsilon^{-2/(2+\beta)} \log_2(\varepsilon^{-1})\right)$, where $\beta \in (0, 1]$ is a local error bound parameter. As an example application of the general algorithm, we show that the distributed geometric median problem, which can be formulated as a constrained convex program, has a dual function that is non-smooth but satisfies the aforementioned local error bound condition with $\beta = 1/2$, and therefore enjoys a convergence time of $O\left(\varepsilon^{-4/5} \log_2(\varepsilon^{-1})\right)$. This result improves upon the $O(\varepsilon^{-1})$ convergence time bound achieved by existing distributed optimization algorithms. Simulation experiments also demonstrate the performance of our proposed algorithm.
-
Optimal transport (OT) offers a versatile framework to compare complex data distributions in a geometrically meaningful way. Traditional methods for computing the Wasserstein distance and geodesic between probability measures require mesh-specific domain discretization and suffer from the curse of dimensionality. We present GeONet, a mesh-invariant deep neural operator network that learns the non-linear mapping from the input pair of initial and terminal distributions to the Wasserstein geodesic connecting the two endpoint distributions. In the offline training stage, GeONet learns the saddle point optimality conditions for the dynamic formulation of the OT problem in the primal and dual spaces that are characterized by a coupled PDE system. The subsequent inference stage is instantaneous and can be deployed for real-time predictions in the online learning setting. We demonstrate that GeONet achieves testing accuracy comparable to standard OT solvers on simulation examples and the MNIST dataset, while reducing the inference-stage computational cost by orders of magnitude.
