Title: Scalable Computations of Wasserstein Barycenter via Input Convex Neural Networks
Wasserstein Barycenter is a principled approach to representing the weighted mean of a given set of probability distributions, utilizing the geometry induced by optimal transport. In this work, we present a novel scalable algorithm to approximate Wasserstein Barycenters, aimed at high-dimensional applications in machine learning. Our proposed algorithm is based on the Kantorovich dual formulation of the Wasserstein-2 distance as well as a recent neural network architecture, the input convex neural network, which is known to parametrize convex functions. The distinguishing features of our method are: i) it only requires samples from the marginal distributions; ii) unlike existing approaches, it represents the Barycenter with a generative model and can thus generate an unlimited number of samples from the barycenter without querying the marginal distributions; iii) it works similarly to a Generative Adversarial Network in the single-marginal case. We demonstrate the efficacy of our algorithm by comparing it with state-of-the-art methods in multiple experiments.
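The abstract rests on two facts: under the Kantorovich dual (Brenier) view of the Wasserstein-2 distance, an optimal transport map can be written as the gradient of a convex potential, and an input convex neural network (ICNN) parametrizes such convex potentials. The following is a minimal NumPy sketch of an ICNN forward pass, assuming softplus activations, hypothetical layer sizes (`hidden=(16, 16)`), and nonnegative hidden-to-hidden weights obtained by taking absolute values at initialization. It is not the authors' architecture or training loop; a trained model would also need to keep those weights nonnegative during optimization (for example by clipping after each update), and the finite-difference gradient only stands in for automatic differentiation.

```python
import numpy as np


def softplus(z):
    # Numerically stable softplus: convex and nondecreasing, so it preserves convexity.
    return np.logaddexp(0.0, z)


class ICNN:
    """Minimal input convex network f(x) with hypothetical layer sizes.

    z_1     = softplus(W0 x + b0)
    z_{k+1} = softplus(Wz_k z_k + Wx_k x + b_k),  with Wz_k >= 0 elementwise
    f(x)    = wz_out . z_K + wx_out . x + b_out,  with wz_out >= 0
    """

    def __init__(self, dim, hidden=(16, 16), seed=0):
        rng = np.random.default_rng(seed)
        self.W0 = rng.normal(size=(hidden[0], dim))
        self.b0 = rng.normal(size=hidden[0])
        self.Wz, self.Wx, self.b = [], [], []
        for h_in, h_out in zip(hidden[:-1], hidden[1:]):
            # abs() enforces nonnegative hidden-to-hidden weights
            self.Wz.append(np.abs(rng.normal(size=(h_out, h_in))))
            self.Wx.append(rng.normal(size=(h_out, dim)))
            self.b.append(rng.normal(size=h_out))
        self.wz_out = np.abs(rng.normal(size=hidden[-1]))
        self.wx_out = rng.normal(size=dim)
        self.b_out = 0.0

    def __call__(self, x):
        # x has shape (n, dim); returns f(x) with shape (n,)
        z = softplus(x @ self.W0.T + self.b0)
        for Wz, Wx, b in zip(self.Wz, self.Wx, self.b):
            z = softplus(z @ Wz.T + x @ Wx.T + b)
        return z @ self.wz_out + x @ self.wx_out + self.b_out

    def transport_map(self, x, h=1e-5):
        # Gradient of the convex potential via central differences (illustration only).
        grads = np.zeros_like(x)
        for j in range(x.shape[1]):
            e = np.zeros(x.shape[1])
            e[j] = h
            grads[:, j] = (self(x + e) - self(x - e)) / (2 * h)
        return grads
```

Because every hidden-to-hidden and output weight on the `z` path is nonnegative and softplus is convex and nondecreasing, `f` is convex in `x`, so `transport_map` is the gradient of a convex function, which is the form an optimal W2 map takes.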
Award ID(s):
2008513 1942523
PAR ID:
10295317
Author(s) / Creator(s):
; ;
Date Published:
Journal Name:
Proceedings of Machine Learning Research
Volume:
139
ISSN:
2640-3498
Page Range / eLocation ID:
1571-1581
Format(s):
Medium: X
Sponsoring Org:
National Science Foundation
More Like this
  1. Optimal transport has emerged as a powerful tool for a variety of problems in machine learning, and it is frequently used to enforce distributional constraints. In this context, existing methods often use either a Wasserstein metric or, when more than two distributions are considered, concurrent barycenter approaches. In this paper, we leverage multi-marginal optimal transport (MMOT), using a procedure that computes a generalized earth mover’s distance as a sub-routine. We show that our algorithm is not only computationally more efficient than other barycenter-based distance methods, but also has the additional advantage that gradients used for backpropagation can be computed efficiently during the forward pass itself, which leads to substantially faster model training. We provide technical details about this new regularization term and its properties, and we present experiments showing faster runtimes than standard Wasserstein-style methods. Finally, on a range of experiments designed to assess effectiveness at enforcing fairness, we demonstrate that our method compares well with alternatives. 
  2. One key challenge encountered in single-cell data clustering is combining clustering results of data sets acquired from multiple sources. We propose to represent the clustering result of each data set by a Gaussian mixture model (GMM) and to produce an integrated result based on the notion of Wasserstein barycenter. However, the exact barycenter of GMMs, which is a distribution on the same sample space, is computationally infeasible to obtain. Moreover, the barycenter of GMMs may not be a GMM with a reasonable number of components. We thus propose to use the minimized aggregated Wasserstein (MAW) distance to approximate the Wasserstein metric and develop a new algorithm for computing the barycenter of GMMs under MAW. Recent theoretical advances further justify using the MAW distance as an approximation of the Wasserstein metric between GMMs. We also prove that the MAW barycenter of GMMs has the same expectation as the Wasserstein barycenter. Our proposed algorithm for clustering integration scales well with the data dimension and the number of mixture components, with complexity independent of data size. We demonstrate that the new method achieves better clustering results on several single-cell RNA-seq data sets than some other popular methods. 
  3. Abernethy, Jacob; Agarwal, Shivani (Ed.)
    We study first order methods to compute the barycenter of a probability distribution $P$ over the space of probability measures with finite second moment. We develop a framework to derive global rates of convergence for both gradient descent and stochastic gradient descent despite the fact that the barycenter functional is not geodesically convex. Our analysis overcomes this technical hurdle by employing a Polyak-Łojasiewicz (PL) inequality and relies on tools from optimal transport and metric geometry. In turn, we establish a PL inequality when $P$ is supported on the Bures-Wasserstein manifold of Gaussian probability measures. It leads to the first global rates of convergence for first order methods in this context. 
  4. We give a new algorithm for learning a two-layer neural network under a general class of input distributions. Assuming there is a ground-truth two-layer network y = Aσ(Wx) + ξ, where A, W are weight matrices, ξ represents noise, and the number of neurons in the hidden layer is no larger than the input or output dimension, our algorithm is guaranteed to recover the parameters A, W of the ground-truth network. The only requirement on the input x is that it is symmetric, which still allows highly complicated and structured input. Our algorithm is based on the method-of-moments framework and extends several results in tensor decompositions. We use spectral algorithms to avoid the complicated non-convex optimization in learning neural networks. Experiments show that our algorithm can robustly learn the ground-truth neural network with a small number of samples for many symmetric input distributions. 
  5. Wasserstein distance plays an increasingly important role in machine learning, stochastic programming, and image processing. Major efforts have been under way to address its high computational complexity, some leading to approximate or regularized variants such as the Sinkhorn distance. However, as we will demonstrate, regularized variants with a large regularization parameter degrade performance in several important machine learning applications, while a small regularization parameter fails due to numerical stability issues with existing algorithms. We address this challenge by developing an Inexact Proximal point method for exact Optimal Transport (IPOT), with the proximal operator approximately evaluated at each iteration using projections onto the probability simplex. The algorithm (a) converges to the exact Wasserstein distance with a theoretical guarantee and robust regularization parameter selection, (b) alleviates numerical stability issues, (c) has computational complexity similar to Sinkhorn, and (d) avoids the shrinking problem when applied to generative models. Furthermore, a new algorithm based on IPOT is proposed to obtain sharper Wasserstein barycenters. 
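The first related abstract (multi-marginal optimal transport for distributional constraints) relies on a generalized earth mover's distance sub-routine that is not reproduced here. As a hedged illustration of why MMOT becomes cheap for one-dimensional marginals, the sketch below assumes k empirical measures on the real line, each with the same number n of equally weighted samples, and the pairwise quadratic cost. In that setting, sorting each marginal and matching by rank (the comonotone coupling) is an optimal multi-marginal coupling, and the weighted average of the rank-matched samples is the 1-D Wasserstein-2 barycenter. The function names are hypothetical.

```python
import numpy as np


def comonotone_mmot_1d(samples):
    """Couple k 1-D empirical measures (n equally weighted points each) by rank.

    For the pairwise quadratic cost sum_{i<j} (x_i - x_j)^2, matching sorted
    samples by rank is optimal in 1-D. Returns the (n, k) array of coupled
    tuples and the average cost per tuple.
    """
    X = np.stack([np.sort(s) for s in samples], axis=1)  # shape (n, k)
    diffs = X[:, :, None] - X[:, None, :]                # all ordered pairs (i, j)
    cost = 0.5 * np.sum(diffs ** 2, axis=(1, 2))         # 0.5 removes double counting
    return X, cost.mean()


def barycenter_1d(samples, weights):
    # 1-D W2 barycenter: weighted average of the quantile functions,
    # i.e. of the rank-matched (sorted) samples.
    X, _ = comonotone_mmot_1d(samples)
    return X @ np.asarray(weights, dtype=float)
```

For example, with samples `[3, 1, 2]` and `[10, 30, 20]` and equal weights, the rank-matched tuples are (1, 10), (2, 20), (3, 30), and the barycenter support is their midpoints.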
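For the second related abstract, the MAW distance between two GMMs is, in its usual formulation, an optimal transport problem between the mixture weights whose ground cost is the squared Wasserstein-2 distance between individual Gaussian components. The sketch below assumes diagonal covariances, so the component distance has the closed form noted in the comment, and uses entropically regularized log-domain Sinkhorn iterations as a swapped-in approximate solver rather than an exact linear program. The `(weights, means, stds)` tuple layout, `eps`, and `iters` are assumptions, not the paper's settings.

```python
import numpy as np


def w2_sq_diag_gauss(m1, s1, m2, s2):
    # Squared W2 between N(m1, diag(s1**2)) and N(m2, diag(s2**2)):
    # ||m1 - m2||^2 + ||s1 - s2||^2 (exact because diagonal covariances commute).
    return np.sum((m1 - m2) ** 2) + np.sum((s1 - s2) ** 2)


def _logsumexp(a, axis):
    # Stable log(sum(exp(a))) along one axis.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def maw_sq(gmm1, gmm2, eps=1e-3, iters=2000):
    """Approximate squared MAW distance between two diagonal-covariance GMMs.

    Each gmm is (weights (K,), means (K, d), stds (K, d)). The coupling of
    mixture components is found with log-domain Sinkhorn iterations.
    """
    (p, M1, S1), (q, M2, S2) = gmm1, gmm2
    C = np.array([[w2_sq_diag_gauss(M1[i], S1[i], M2[j], S2[j])
                   for j in range(len(q))] for i in range(len(p))])
    f, g = np.zeros(len(p)), np.zeros(len(q))
    for _ in range(iters):
        # Dual updates that enforce the row and column marginals in turn.
        f = eps * np.log(p) - eps * _logsumexp((g[None, :] - C) / eps, axis=1)
        g = eps * np.log(q) - eps * _logsumexp((f[:, None] - C) / eps, axis=0)
    P = np.exp((f[:, None] + g[None, :] - C) / eps)  # approximate component coupling
    return float(np.sum(P * C)), P
```

Translating every component of a GMM by the same vector d should give a squared MAW distance of about ||d||^2, since matching each component to its own shifted copy is then optimal.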
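The fourth related abstract builds on the method of moments and tensor decompositions. As a hedged illustration of that toolbox, and not the paper's algorithm (which treats a more general setting), the standard tensor power method with deflation recovers the components of a symmetric, orthogonally decomposable third-order tensor T = sum_i lambda_i a_i ⊗ a_i ⊗ a_i. Estimating such a tensor from data and whitening it are omitted, and the restart and iteration counts are assumptions.

```python
import numpy as np


def tensor_power_method(T, n_components, n_restarts=10, n_iters=100, seed=0):
    """Recover (lambda_i, a_i) from T = sum_i lambda_i a_i (x) a_i (x) a_i,
    assuming the a_i are orthonormal and lambda_i > 0."""
    rng = np.random.default_rng(seed)
    T = T.copy()
    d = T.shape[0]
    lams, vecs = [], []
    for _ in range(n_components):
        best_u, best_lam = None, -np.inf
        for _ in range(n_restarts):
            u = rng.normal(size=d)
            u /= np.linalg.norm(u)
            for _ in range(n_iters):
                u = np.einsum('ijk,j,k->i', T, u, u)  # u <- T(I, u, u)
                u /= np.linalg.norm(u)
            lam = np.einsum('ijk,i,j,k->', T, u, u, u)  # lambda = T(u, u, u)
            if lam > best_lam:
                best_u, best_lam = u, lam
        lams.append(best_lam)
        vecs.append(best_u)
        # Deflate: remove the recovered rank-one component.
        T = T - best_lam * np.einsum('i,j,k->ijk', best_u, best_u, best_u)
    return np.array(lams), np.array(vecs)
```

Random restarts guard against a poor starting vector, and deflation lets the same iteration find the remaining components one at a time.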
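The fifth related abstract describes IPOT. A minimal NumPy sketch of the proximal point iteration follows: each outer step approximately solves min_T <C, T> + beta * KL(T || T_prev) using only a few Sinkhorn-style scaling updates on the kernel exp(-C / beta) ⊙ T_prev, which is what makes each proximal step inexact. It assumes a cost matrix already scaled to roughly unit magnitude; `beta`, `n_outer`, and `n_inner` are illustrative settings, and the barycenter extension is omitted.

```python
import numpy as np


def ipot(mu, nu, C, beta=0.1, n_outer=500, n_inner=1):
    """IPOT-style sketch: returns the transport cost <C, T> and the plan T."""
    n, m = C.shape
    G = np.exp(-C / beta)       # proximal kernel
    T = np.ones((n, m))         # initial plan (uniform, not yet feasible)
    b = np.ones(m) / m
    for _ in range(n_outer):
        Q = G * T               # kernel reweighted by the previous plan
        for _ in range(n_inner):
            a = mu / (Q @ b)    # match row marginals
            b = nu / (Q.T @ a)  # match column marginals
        T = a[:, None] * Q * b[None, :]
    return float(np.sum(T * C)), T
```

Unlike Sinkhorn with a fixed regularization, the effective regularization shrinks as the outer iterations proceed, so the plan approaches an exact optimal coupling. On four evenly spaced points matched to the same points shifted by 0.5, with squared-distance cost, the cost should approach the exact value 0.25 given by the monotone matching.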