Title: U-Statistics for Importance-Weighted Variational Inference
We propose the use of U-statistics to reduce the variance of gradient estimates in importance-weighted variational inference. The key observation is that, given a base gradient estimator that requires m > 1 samples and a total of n > m samples available for estimation, lower variance is achieved by averaging the base estimator over overlapping batches of size m than over disjoint batches, as is currently done. We use classical U-statistic theory to analyze the variance reduction, and propose novel approximations with theoretical guarantees to ensure computational efficiency. We find empirically that U-statistic variance reduction can lead to modest to significant improvements in inference performance on a range of models, with little computational cost.
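The overlapping-versus-disjoint averaging can be sketched with a toy 2-sample kernel in place of the paper's actual gradient estimator (which would be an importance-weighted gradient); the kernel h(x, y) = (x - y)^2 / 2 is a classical choice whose U-statistic is exactly the unbiased sample variance, which gives a convenient sanity check. All names here are illustrative, not from the paper.

```python
from itertools import combinations
from statistics import mean, variance

def u_statistic(xs, m, h):
    # U-statistic: average the m-sample base estimator h over ALL
    # overlapping size-m subsets of the n samples.
    return mean(h(*c) for c in combinations(xs, m))

def disjoint_estimate(xs, m, h):
    # Standard approach: average h over the n // m disjoint batches.
    batches = [xs[i:i + m] for i in range(0, len(xs) - m + 1, m)]
    return mean(h(*b) for b in batches)

# Sanity check: with the 2-sample kernel h(x, y) = (x - y)^2 / 2,
# the U-statistic equals the unbiased sample variance.
h_var = lambda x, y: 0.5 * (x - y) ** 2
xs = [1.0, 4.0, 2.0, 8.0, 5.0, 7.0]
u = u_statistic(xs, 2, h_var)        # equals variance(xs) = 7.5
d = disjoint_estimate(xs, 2, h_var)  # same expectation, higher variance across runs
```

Both estimators are unbiased for the same quantity; the U-statistic simply averages over C(n, m) batches instead of n // m, which is the source of the variance reduction the paper analyzes.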
Award ID(s):
1749854
PAR ID:
10473118
Author(s) / Creator(s):
; ; ;
Publisher / Repository:
Transactions on Machine Learning Research
Date Published:
Journal Name:
Transactions on Machine Learning Research
ISSN:
2835-8856
Format(s):
Medium: X
Sponsoring Org:
National Science Foundation
More Like this
  1. Chen, Yi-Hau; Stufken, John; Judy_Wang, Huixia (Ed.)
    Though introduced nearly 50 years ago, the infinitesimal jackknife (IJ) remains a popular modern tool for quantifying predictive uncertainty in complex estimation settings. In particular, when supervised learning ensembles are constructed via bootstrap samples, recent work demonstrated that the IJ estimate of variance is particularly convenient and useful. However, despite the algebraic simplicity of its final form, its derivation is rather complex. As a result, studies clarifying the intuition behind the estimator or rigorously investigating its properties have been severely lacking. This work aims to take a step forward on both fronts. We demonstrate that surprisingly, the exact form of the IJ estimator can be obtained via a straightforward linear regression of the individual bootstrap estimates on their respective weights or via the classical jackknife. The latter realization allows us to formally investigate the bias of the IJ variance estimator and better characterize the settings in which its use is appropriate. Finally, we extend these results to the case of U-statistics where base models are constructed via subsampling rather than bootstrapping and provide a consistent estimate of the resulting variance. 
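The "algebraically simple final form" mentioned above is the sum of squared per-observation covariances between the bootstrap count of each observation and the bootstrap estimates. A minimal sketch of that estimator follows; the function name and Monte Carlo setup are illustrative assumptions, not the paper's code.

```python
import random
from statistics import mean

def ij_variance(data, estimator, n_boot=2000, seed=0):
    # Infinitesimal-jackknife variance for a bootstrap-averaged estimator:
    # V_IJ = sum_i Cov(N_i, theta*)^2, where N_i is how often observation i
    # appears in a bootstrap sample and theta* is the estimate on that sample.
    # (Equivalently, the squared slopes of a regression of theta* on the N_i,
    # the reinterpretation the abstract describes.)
    rng = random.Random(seed)
    n = len(data)
    counts, thetas = [], []
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]
        c = [0] * n
        for j in idx:
            c[j] += 1
        counts.append(c)
        thetas.append(estimator([data[j] for j in idx]))
    tbar = mean(thetas)
    total = 0.0
    for i in range(n):
        cbar = mean(c[i] for c in counts)
        cov = mean((c[i] - cbar) * (t - tbar) for c, t in zip(counts, thetas))
        total += cov ** 2
    return total
```

For the sample mean the exact IJ value is the plug-in variance divided by n, which gives a quick way to check the Monte Carlo approximation.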
  2. Meila, Marina; Zhang, Tong (Ed.)
    Black-box variational inference algorithms use stochastic sampling to analyze diverse statistical models, like those expressed in probabilistic programming languages, without model-specific derivations. While the popular score-function estimator computes unbiased gradient estimates, its variance is often unacceptably large, especially in models with discrete latent variables. We propose a stochastic natural gradient estimator that is as broadly applicable and unbiased, but improves efficiency by exploiting the curvature of the variational bound, and provably reduces variance by marginalizing discrete latent variables. Our marginalized stochastic natural gradients have intriguing connections to classic coordinate ascent variational inference, but allow parallel updates of variational parameters, and provide superior convergence guarantees relative to naive Monte Carlo approximations. We integrate our method with the probabilistic programming language Pyro and evaluate real-world models of documents, images, networks, and crowd-sourcing. Compared to score-function estimators, we require far fewer Monte Carlo samples and consistently converge orders of magnitude faster.
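The contrast between the high-variance score-function baseline and marginalizing a discrete latent can be seen on a single Bernoulli variable; this is a generic illustration of the two estimator families, not the paper's natural-gradient method, and all names are assumptions.

```python
import random
from statistics import mean, pstdev

def score_function_grad_samples(p, f, n_samples, rng):
    # Score-function (REINFORCE) estimator of d/dp E_{z~Bern(p)}[f(z)]:
    # each sample is f(z) * d/dp log Bern(z; p), which is unbiased but noisy.
    out = []
    for _ in range(n_samples):
        z = 1 if rng.random() < p else 0
        score = 1.0 / p if z == 1 else -1.0 / (1.0 - p)
        out.append(f(z) * score)
    return out

def marginalized_grad(f):
    # Summing the discrete latent out gives the gradient exactly, with
    # zero variance: d/dp [(1 - p) f(0) + p f(1)] = f(1) - f(0).
    return f(1) - f(0)

f = lambda z: 3.0 * z + 1.0
samples = score_function_grad_samples(0.3, f, 20000, random.Random(0))
```

Both estimators target the same gradient, but the score-function samples have standard deviation of several units while the marginalized value is deterministic, which is the gap the marginalization result in the abstract exploits.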
  3. Wallach, H.; Larochelle, H.; Beygelzimer, A.; d'Alché-Buc, F.; Fox, E.; Garnett, R. (Ed.)
    Variance reduction has emerged in recent years as a strong competitor to stochastic gradient descent in non-convex problems, providing the first algorithms to improve upon the convergence rate of stochastic gradient descent for finding first-order critical points. However, variance reduction techniques typically require carefully tuned learning rates and willingness to use excessively large "mega-batches" in order to achieve their improved results. We present a new algorithm, STORM, that does not require any batches and makes use of adaptive learning rates, enabling simpler implementation and less hyperparameter tuning. Our technique for removing the batches uses a variant of momentum to achieve variance reduction in non-convex optimization. On smooth losses $$F$$, STORM finds a point $$\boldsymbol{x}$$ with $$E[\|\nabla F(\boldsymbol{x})\|]\le O(1/\sqrt{T}+\sigma^{1/3}/T^{1/3})$$ in $$T$$ iterations with $$\sigma^2$$ variance in the gradients, matching the optimal rate and without requiring knowledge of $$\sigma$$.
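The momentum variant at the heart of STORM is the recursive update d_t = g(x_t, xi_t) + (1 - a)(d_{t-1} - g(x_{t-1}, xi_t)), which evaluates the stochastic gradient at two points with the same noise sample so that the error partially cancels. A minimal sketch on a noisy quadratic follows; the fixed step size and momentum parameter are simplifying assumptions (the paper adapts both), and the function names are illustrative.

```python
import random

def storm_minimize(stoch_grad, x0, steps, lr=0.1, a=0.1, seed=0):
    # STORM-style recursive momentum:
    #   d_t = g(x_t, xi_t) + (1 - a) * (d_{t-1} - g(x_{t-1}, xi_t)).
    # Using the SAME noise sample xi_t at both x_t and x_{t-1} lets the
    # stochastic error partially cancel, giving variance reduction
    # without mega-batches.
    rng = random.Random(seed)
    x = x0
    xi = rng.gauss(0.0, 1.0)
    d = stoch_grad(x, xi)
    for _ in range(steps):
        x_prev, x = x, x - lr * d
        xi = rng.gauss(0.0, 1.0)
        d = stoch_grad(x, xi) + (1.0 - a) * (d - stoch_grad(x_prev, xi))
    return x

# Noisy gradient of f(x) = (x - 3)^2 / 2: true gradient plus Gaussian noise.
noisy_grad = lambda x, xi: (x - 3.0) + 0.5 * xi
x_final = storm_minimize(noisy_grad, 0.0, 500)
```

On this linear-gradient example the correction term cancels the noise almost entirely, so the iterate settles very close to the minimizer at x = 3.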
  4. Chaudhuri, Kamalika and (Ed.)
    While deep generative models have succeeded in image processing, natural language processing, and reinforcement learning, training that involves discrete random variables remains challenging due to the high variance of its gradient estimation process. Monte Carlo sampling is a common ingredient in most variance reduction approaches, but it involves time-consuming resampling and multiple function evaluations. We propose a Gapped Straight-Through (GST) estimator to reduce the variance without incurring resampling overhead. This estimator is inspired by the essential properties of Straight-Through Gumbel-Softmax. We determine these properties and show via an ablation study that they are essential. Experiments demonstrate that the proposed GST estimator enjoys better performance compared to strong baselines on two discrete deep generative modeling tasks, MNIST-VAE and ListOps.
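The Straight-Through Gumbel-Softmax estimator that GST builds on passes a hard one-hot sample forward while keeping the soft relaxation for the backward pass. A forward-pass-only sketch of that baseline (not the authors' GST, and without an autograd framework) follows; names and parameters are illustrative.

```python
import math
import random

def gumbel_softmax_st(logits, tau=1.0, rng=random):
    # Straight-Through Gumbel-Softmax forward pass (the baseline GST refines):
    # perturb logits with Gumbel(0, 1) noise, take a hard one-hot argmax for
    # the forward pass, but keep the tempered softmax probabilities that the
    # backward pass would differentiate through.
    g = [-math.log(-math.log(rng.random())) for _ in logits]  # Gumbel(0, 1) noise
    y = [(l + n) / tau for l, n in zip(logits, g)]
    m = max(y)
    exp = [math.exp(v - m) for v in y]       # stable softmax
    z = sum(exp)
    soft = [e / z for e in exp]              # used in the backward pass
    k = max(range(len(y)), key=lambda i: y[i])
    hard = [1.0 if i == k else 0.0 for i in range(len(y))]
    return hard, soft
```

Because the argmax of the perturbed logits and the argmax of the softmax coincide, the hard sample always agrees with the mode of the soft relaxation.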
  5. As prompts become central to Large Language Models (LLMs), optimizing them is vital. Textual Stochastic Gradient Descent (TSGD) offers a data-driven approach by iteratively refining prompts using LLM-suggested updates over minibatches. We empirically show that increasing training data initially improves but can later degrade TSGD's performance across NLP tasks, while also raising computational costs. To address this, we propose Textual Stochastic Gradient Descent with Momentum (TSGD-M)—a scalable method that reweights prompt sampling based on past batches. Evaluated on 9 NLP tasks across three domains, TSGD-M outperforms TSGD baselines for most tasks and reduces performance variance. 
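The abstract describes TSGD-M's reweighting only at a high level. The following is a generic sketch of momentum-style reweighted sampling, where an exponential moving average of per-batch scores shifts which examples get drawn next; every name and the scoring rule here are illustrative assumptions, not the paper's algorithm.

```python
import random

def momentum_reweight(weights, batch_ids, batch_scores, beta=0.9):
    # Generic momentum-style reweighting (illustrative, NOT the paper's TSGD-M):
    # blend each example's old sampling weight with its score from the latest
    # batch, so past batches keep influencing which examples are drawn next.
    new = dict(weights)
    for i, s in zip(batch_ids, batch_scores):
        new[i] = beta * new[i] + (1.0 - beta) * s
    return new

def sample_batch(weights, k, rng):
    # Draw a minibatch with probability proportional to the current weights.
    ids = list(weights)
    return rng.choices(ids, weights=[weights[i] for i in ids], k=k)

w = momentum_reweight({0: 1.0, 1: 1.0, 2: 1.0}, [0], [2.0])
```

The momentum term (beta close to 1) smooths out noise in any single batch's scores, which is the same intuition behind the variance reduction the abstract reports.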