Title: Scalable Computations of Wasserstein Barycenter via Input Convex Neural Networks
The Wasserstein barycenter is a principled way to represent the weighted mean of a given set of probability distributions, using the geometry induced by optimal transport. In this work, we present a novel scalable algorithm for approximating Wasserstein barycenters, aimed at high-dimensional applications in machine learning. Our algorithm is based on the Kantorovich dual formulation of the Wasserstein-2 distance and on a recent neural network architecture, the input convex neural network, which is known to parametrize convex functions. The distinguishing features of our method are: i) it requires only samples from the marginal distributions; ii) unlike existing approaches, it represents the barycenter with a generative model and can thus generate an unlimited number of samples from the barycenter without querying the marginal distributions; iii) in the one-marginal case, it works similarly to a generative adversarial model. We demonstrate the efficacy of our algorithm by comparing it with state-of-the-art methods in multiple experiments.
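As a rough illustration of the input convex neural network idea mentioned in the abstract, the sketch below builds a small network whose scalar output is convex in its input. It is not the architecture used in the paper: the class name `ICNN`, the layer sizes, the softplus activation, and the random initialization are all assumptions made for this example. Convexity follows from two standard ingredients: the hidden-to-hidden weights are non-negative, and the activation is convex and non-decreasing.

```python
import numpy as np


def softplus(t):
    # Convex, non-decreasing activation; logaddexp keeps it numerically stable.
    return np.logaddexp(0.0, t)


class ICNN:
    """Hypothetical minimal input convex network f: R^dim -> R (illustration only)."""

    def __init__(self, dim, hidden=16, layers=3, seed=0):
        rng = np.random.default_rng(seed)
        sizes = [hidden] * layers + [1]
        # Skip connections from the input x to every layer; any sign is allowed here.
        self.Wx = [rng.normal(scale=0.5, size=(s, dim)) for s in sizes]
        self.b = [rng.normal(scale=0.1, size=s) for s in sizes]
        # Layer-to-layer weights are forced non-negative to preserve convexity.
        self.Wz = [
            np.abs(rng.normal(scale=0.5, size=(sizes[k + 1], sizes[k])))
            for k in range(len(sizes) - 1)
        ]

    def __call__(self, x):
        # First layer: convex activation of an affine map of x.
        z = softplus(self.Wx[0] @ x + self.b[0])
        # Hidden layers: non-negative mix of convex features plus an affine term in x.
        for k, Wz in enumerate(self.Wz[:-1], start=1):
            z = softplus(Wz @ z + self.Wx[k] @ x + self.b[k])
        # Output layer: non-negative combination plus an affine term stays convex.
        out = self.Wz[-1] @ z + self.Wx[-1] @ x + self.b[-1]
        return float(out[0])


if __name__ == "__main__":
    f = ICNN(dim=2)
    x, y = np.array([1.0, -2.0]), np.array([-0.5, 3.0])
    # Midpoint convexity: f((x + y) / 2) <= (f(x) + f(y)) / 2
    print(f(0.5 * (x + y)) <= 0.5 * (f(x) + f(y)))
```

In the Wasserstein-2 setting, the optimal transport map between suitable distributions is the gradient of a convex potential, which is why a convex parametrization of this kind is a natural fit; training such a network on the dual objective is beyond the scope of this sketch.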
Award ID(s):
2008513 1942523
NSF-PAR ID:
10295317
Author(s) / Creator(s):
; ;
Date Published:
Journal Name:
Proceedings of Machine Learning Research
Volume:
139
ISSN:
2640-3498
Page Range / eLocation ID:
1571-1581
Format(s):
Medium: X
Sponsoring Org:
National Science Foundation
More Like this
  1. Optimal transport has emerged as a powerful tool for a variety of problems in machine learning, and it is frequently used to enforce distributional constraints. In this context, existing methods often use either a Wasserstein metric, or else they apply concurrent barycenter approaches when more than two distributions are considered. In this paper, we leverage multi-marginal optimal transport (MMOT), where we take advantage of a procedure that computes a generalized earth mover’s distance as a sub-routine. We show that not only is our algorithm computationally more efficient compared to other barycentric-based distance methods, but it has the additional advantage that gradients used for backpropagation can be efficiently computed during the forward pass computation itself, which leads to substantially faster model training. We provide technical details about this new regularization term and its properties, and we present experimental demonstrations of faster runtimes when compared to standard Wasserstein-style methods. Finally, on a range of experiments designed to assess effectiveness at enforcing fairness, we demonstrate our method compares well with alternatives. 
  2. One key challenge encountered in single-cell data clustering is to combine clustering results of data sets acquired from multiple sources. We propose to represent the clustering result of each data set by a Gaussian mixture model (GMM) and produce an integrated result based on the notion of Wasserstein barycenter. However, the precise barycenter of GMMs, a distribution on the same sample space, is computationally infeasible to solve. Importantly, the barycenter of GMMs may not be a GMM containing a reasonable number of components. We thus propose to use the minimized aggregated Wasserstein (MAW) distance to approximate the Wasserstein metric and develop a new algorithm for computing the barycenter of GMMs under MAW. Recent theoretical advances further justify using the MAW distance as an approximation for the Wasserstein metric between GMMs. We also prove that the MAW barycenter of GMMs has the same expectation as the Wasserstein barycenter. Our proposed algorithm for clustering integration scales well with the data dimension and the number of mixture components, with complexity independent of data size. We demonstrate that the new method achieves better clustering results on several single-cell RNA-seq data sets than some other popular methods.
  3. Wasserstein distance plays increasingly important roles in machine learning, stochastic programming and image processing. Major efforts have been under way to address its high computational complexity, some leading to approximate or regularized variations such as Sinkhorn distance. However, as we will demonstrate, regularized variations with a large regularization parameter degrade performance in several important machine learning applications, while a small regularization parameter fails due to numerical stability issues with existing algorithms. We address this challenge by developing an Inexact Proximal point method for exact Optimal Transport problem (IPOT) with the proximal operator approximately evaluated at each iteration using projections to the probability simplex. The algorithm (a) converges to the exact Wasserstein distance with a theoretical guarantee and robust regularization parameter selection, (b) alleviates numerical stability issues, (c) has similar computational complexity to Sinkhorn, and (d) avoids the shrinking problem when applied to generative models. Furthermore, a new algorithm is proposed based on IPOT to obtain sharper Wasserstein barycenters.
  4. Generative Adversarial Networks trained on samples of simulated or actual events have been proposed as a way of generating large simulated datasets at a reduced computational cost. In this work, a novel approach to perform the simulation of photodetector signals from the time projection chamber of the EXO-200 experiment is demonstrated. The method is based on a Wasserstein Generative Adversarial Network, a deep learning technique allowing for implicit non-parametric estimation of the population distribution for a given set of objects. Our network is trained on real calibration data using raw scintillation waveforms as input. We find that it is able to produce high-quality simulated waveforms an order of magnitude faster than the traditional simulation approach and, importantly, generalize from the training sample and discern salient high-level features of the data. In particular, the network correctly deduces position dependency of scintillation light response in the detector and correctly recognizes dead photodetector channels. The network output is then integrated into the EXO-200 analysis framework to show that the standard EXO-200 reconstruction routine processes the simulated waveforms to produce energy distributions comparable to those of real waveforms. Finally, the remaining discrepancies and potential ways to improve the approach further are highlighted.
  5. We give a new algorithm for learning a two-layer neural network under a general class of input distributions. Assuming there is a ground-truth two-layer network y = Aσ(Wx) + ξ, where A,W are weight matrices, ξ represents noise, and the number of neurons in the hidden layer is no larger than the input or output, our algorithm is guaranteed to recover the parameters A,W of the ground-truth network. The only requirement on the input x is that it is symmetric, which still allows highly complicated and structured input. Our algorithm is based on the method-of-moments framework and extends several results in tensor decompositions. We use spectral algorithms to avoid the complicated non-convex optimization in learning neural networks. Experiments show that our algorithm can robustly learn the ground-truth neural network with a small number of samples for many symmetric input distributions. 