Title: Meta-Learning with Implicit Gradients
A core capability of intelligent systems is the ability to quickly learn new tasks by drawing on prior experience. Gradient (or optimization) based meta-learning has recently emerged as an effective approach for few-shot learning. In this formulation, meta-parameters are learned in the outer loop, while task-specific models are learned in the inner loop, using only a small amount of data from the current task. A key challenge in scaling these approaches is the need to differentiate through the inner loop learning process, which can impose considerable computational and memory burdens. By drawing upon implicit differentiation, we develop the implicit MAML algorithm, which depends only on the solution to the inner-level optimization and not the path taken by the inner loop optimizer. This effectively decouples the meta-gradient computation from the choice of inner loop optimizer. As a result, our approach is agnostic to the choice of inner loop optimizer and can gracefully handle many gradient steps without vanishing gradients or memory constraints. Theoretically, we prove that implicit MAML can compute accurate meta-gradients with a memory footprint that is, up to small constant factors, no more than that which is required to compute a single inner loop gradient and at no overall increase in the total computational cost. Experimentally, we show that these benefits of implicit MAML translate into empirical gains on few-shot image recognition benchmarks.
Award ID(s):
1740551 1703574
NSF-PAR ID:
10184688
Journal Name:
Advances in neural information processing systems
ISSN:
1049-5258
Format(s):
Medium: X
Sponsoring Org:
National Science Foundation
More Like this
  1. The problem of learning to generalize to classes unseen during training, also known as few-shot classification, has attracted considerable attention. Initialization-based methods, such as gradient-based model-agnostic meta-learning (MAML) [1], tackle the few-shot learning problem by “learning to fine-tune”. The goal of these approaches is to learn a model initialization from which classifiers for new classes can be learned from a few labeled examples with a small number of gradient update steps. Few-shot meta-learning is well known for its fast adaptation and accurate generalization to unseen tasks [2]. Learning fairly, with unbiased outcomes, is another significant hallmark of human intelligence, yet it is rarely addressed in few-shot meta-learning. In this work, we propose a novel Primal-Dual Fair Meta-learning framework, namely PDFM, which learns to train fair machine learning models using only a few examples based on data from related tasks. The key idea is to learn a good initialization of a fair model’s primal and dual parameters so that it can adapt to a new fair learning task via a few gradient update steps. Instead of manually tuning the dual parameters as hyperparameters via a grid search, PDFM optimizes the initialization of the primal and dual parameters jointly for fair meta-learning via a subgradient primal-dual approach. We further instantiate an example of bias control using decision boundary covariance (DBC) [3] as the fairness constraint for each task, and demonstrate the versatility of our proposed approach by applying it to classification on three real-world datasets. Our experiments show substantial improvements over the best prior work for this setting.
  2. Model-Agnostic Meta-Learning (MAML), a popular gradient-based meta-learning framework, assumes that the contribution of each task or instance to the meta-learner is equal. Hence, it fails to address the domain shift between base and novel classes in few-shot learning. In this work, we propose a novel robust meta-learning algorithm, NESTEDMAML, which learns to assign weights to training tasks or instances. We consider weights as hyper-parameters and iteratively optimize them using a small set of validation tasks in a nested bi-level optimization approach (in contrast to the standard bi-level optimization in MAML). We then apply NESTEDMAML in the meta-training stage, which involves (1) several tasks sampled from a distribution different from the meta-test task distribution, or (2) some data samples with noisy labels. Extensive experiments on synthetic and real-world datasets demonstrate that NESTEDMAML efficiently mitigates the effects of “unwanted” tasks or instances, leading to significant improvement over the state-of-the-art robust meta-learning methods.
  3. Tasks across diverse application domains can be posed as large-scale optimization problems; these include graphics, vision, machine learning, imaging, health, scheduling, planning, and energy system forecasting. Independently of the application domain, proximal algorithms have emerged as a formal optimization method that successfully solves a wide array of existing problems, often exploiting problem-specific structures in the optimization. Although model-based formal optimization provides a principled approach to problem modeling with convergence guarantees, at first glance this seems to be at odds with black-box deep learning methods. A recent line of work shows that, when combined with learning-based ingredients, model-based optimization methods are effective, interpretable, and allow for generalization to a wide spectrum of applications with little or no extra training data. However, experimenting with such hybrid approaches for different tasks by hand requires domain expertise in both proximal optimization and deep learning, and is often error-prone and time-consuming. Moreover, naively unrolling these iterative methods produces lengthy compute graphs, which, when differentiated via autograd techniques, result in exploding memory consumption, making batch-based training challenging. In this work, we introduce ∇-Prox, a domain-specific modeling language and compiler for large-scale optimization problems using differentiable proximal algorithms. ∇-Prox allows users to specify optimization objective functions of unknowns concisely at a high level, and intelligently compiles the problem into compute- and memory-efficient differentiable solvers. One of the core features of ∇-Prox is its full differentiability, which supports hybrid model- and learning-based solvers integrating proximal optimization with neural network pipelines. Example applications of this methodology include learning-based priors and/or sample-dependent inner-loop optimization schedulers, learned with deep equilibrium learning or deep reinforcement learning. With a few lines of code, we show ∇-Prox can generate performant solvers for a range of image optimization problems, including end-to-end computational optics, image deraining, and compressive magnetic resonance imaging. We also demonstrate ∇-Prox can be used in a completely orthogonal application domain of energy system planning, an essential task in the energy crisis and the clean energy transition, where it outperforms state-of-the-art CVXPY and commercial Gurobi solvers.
  4. Few-shot classification (FSC) requires training models using a few (typically one to five) data points per class. Meta-learning has proven to be able to learn a parametrized model for FSC by training on various other classification tasks. In this work, we propose PLATINUM (semi-suPervised modeL Agnostic meTa-learnIng usiNg sUbmodular Mutual information), a novel semi-supervised model-agnostic meta-learning framework that uses submodular mutual information (SMI) functions to boost the performance of FSC. PLATINUM leverages unlabeled data in the inner and outer loop using SMI functions during meta-training and obtains richer meta-learned parameterizations for meta-test. We study the performance of PLATINUM in two scenarios: 1) where the unlabeled data points belong to the same set of classes as the labeled set of a certain episode, and 2) where there exist out-of-distribution classes that do not belong to the labeled set. We evaluate our method on various settings on the miniImageNet, tieredImageNet and Fewshot-CIFAR100 datasets. Our experiments show that PLATINUM outperforms MAML and semi-supervised approaches like pseudo-labeling for semi-supervised FSC, especially for a small ratio of labeled examples per class.
  5. Chaudhuri, Kamalika ; Jegelka, Stefanie ; Song, Le ; Szepesvari, Csaba ; Niu, Gang ; Sabato, Sivan (Ed.)
    Few-shot classification (FSC) requires training models using a few (typically one to five) data points per class. Meta-learning has proven to be able to learn a parametrized model for FSC by training on various other classification tasks. In this work, we propose PLATINUM (semi-suPervised modeL Agnostic meTa learnIng usiNg sUbmodular Mutual information), a novel semi-supervised model-agnostic meta-learning framework that uses submodular mutual information (SMI) functions to boost the performance of FSC. PLATINUM leverages unlabeled data in the inner and outer loop using SMI functions during meta-training and obtains richer meta-learned parameterizations. We study the performance of PLATINUM in two scenarios: 1) where the unlabeled data points belong to the same set of classes as the labeled set of a certain episode, and 2) where there exist out-of-distribution classes that do not belong to the labeled set. We evaluate our method on various settings on the miniImageNet, tieredImageNet and CIFAR-FS datasets. Our experiments show that PLATINUM outperforms MAML and semi-supervised approaches like pseudo-labeling for semi-supervised FSC, especially for a small ratio of labeled to unlabeled samples.