Title: Heterogeneous Ensemble Knowledge Transfer for Training Large Models in Federated Learning
Federated learning (FL) enables edge devices to collaboratively learn a model without disclosing their private data to a central aggregating server. Most existing FL algorithms require models of identical architecture to be deployed across the clients and server, making it infeasible to train large models due to clients' limited system resources. In this work, we propose a novel ensemble knowledge transfer method named Fed-ET in which small models (different in architecture) are trained on clients, and used to train a larger model at the server. Unlike in conventional ensemble learning, in FL the ensemble can be trained on clients' highly heterogeneous data. Cognizant of this property, Fed-ET uses a weighted consensus distillation scheme with diversity regularization that efficiently extracts reliable consensus from the ensemble while improving generalization by exploiting the diversity within the ensemble. We show the generalization bound for the ensemble of weighted models trained on heterogeneous datasets that supports the intuition of Fed-ET. Our experiments on image and language tasks show that Fed-ET significantly outperforms other state-of-the-art FL algorithms with fewer communicated parameters, and is also robust against high data heterogeneity.
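The idea of extracting a weighted consensus from an ensemble of client models can be illustrated with a minimal sketch. This is not Fed-ET's actual weighting rule or loss, which the abstract does not specify: the function names `softmax` and `weighted_consensus`, and the default choice of weighting each client by its maximum predicted probability, are assumptions made purely for illustration. Diversity regularization is not covered here.

```python
import math


def softmax(logits):
    """Convert a list of logits into probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_consensus(client_logits, weights=None):
    """Combine per-client predictions for ONE unlabeled sample.

    client_logits: list of logit vectors, one per client model.
    weights: optional per-client weights. If omitted, each client is
        weighted by its own maximum class probability. This default is a
        hypothetical confidence heuristic, NOT the rule used by Fed-ET.
    Returns (consensus_probs, pseudo_label).
    """
    probs = [softmax(logits) for logits in client_logits]
    if weights is None:
        weights = [max(p) for p in probs]  # assumed confidence proxy
    total = sum(weights)
    num_classes = len(probs[0])
    # Weighted average of the clients' class probabilities.
    consensus = [
        sum(w * p[c] for w, p in zip(weights, probs)) / total
        for c in range(num_classes)
    ]
    # The pseudo-label is the class with the highest consensus probability.
    pseudo_label = max(range(num_classes), key=consensus.__getitem__)
    return consensus, pseudo_label


if __name__ == "__main__":
    # Three hypothetical client models scoring one sample over three classes.
    logits = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [3.0, 0.0, 0.0]]
    consensus, label = weighted_consensus(logits)
    print(consensus, label)
```

In a distillation setting, a server model would then be trained on unlabeled data to match consensus targets of this kind; how the weights are chosen is exactly where a method such as Fed-ET would differ from this sketch.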
Award ID(s):
2107024
NSF-PAR ID:
10356194
Journal Name:
Proceedings of the Thirty-First International Joint Conference on Artificial Intelligence (IJCAI) Main Track
Sponsoring Org:
National Science Foundation
More Like this
  1. The Deep Operator Network (DeepONet) framework is a different class of neural network architecture that one trains to learn nonlinear operators, i.e., mappings between infinite-dimensional spaces. Traditionally, DeepONets are trained using a centralized strategy that requires transferring the training data to a centralized location. Such a strategy, however, limits our ability to secure data privacy or use high-performance distributed/parallel computing platforms. To alleviate such limitations, in this paper, we study the federated training of DeepONets for the first time. That is, we develop a framework, which we refer to as Fed-DeepONet, that allows multiple clients to train DeepONets collaboratively under the coordination of a centralized server. To achieve Fed-DeepONets, we propose an efficient stochastic gradient-based algorithm that enables the distributed optimization of the DeepONet parameters by averaging first-order estimates of the DeepONet loss gradient. Then, to accelerate the training convergence of Fed-DeepONets, we propose a moment-enhanced (i.e., adaptive) stochastic gradient-based strategy. Finally, we verify the performance of Fed-DeepONet by learning, for different configurations of the number of clients and fractions of available clients, (i) the solution operator of a gravity pendulum and (ii) the dynamic response of a parametric library of pendulums. 
  2. A central theme in federated learning (FL) is the fact that client data distributions are often not independent and identically distributed (IID), which has strong implications on the training process. While most existing FL algorithms focus on the conventional non-IID setting of class imbalance or missing classes across clients, in practice, the distribution differences could be more complex, e.g., changes in class conditional (domain) distributions. In this paper, we consider this complex case in FL wherein each client has access to only one domain distribution. For tasks such as domain generalization, most existing learning algorithms require access to data from multiple clients (i.e., from multiple domains) during training, which is prohibitive in FL. To address this challenge, we propose a federated domain translation method that generates pseudodata for each client which could be useful for multiple downstream learning tasks. We empirically demonstrate that our translation model is more resource-efficient (in terms of both communication and computation) and easier to train in an FL setting than standard domain translation methods. Furthermore, we demonstrate that the learned translation model enables use of state-of-the-art domain generalization methods in a federated setting, which enhances accuracy and robustness to increases in the synchronization period compared to existing methodology. 
  3. Solar flare prediction is a central problem in space weather forecasting and has captivated the attention of a wide spectrum of researchers due to recent advances in remote sensing as well as in machine learning and deep learning approaches. The experimental findings based on both machine and deep learning models reveal significant performance improvements for task-specific datasets. Along with building models, the practice of deploying such models to production environments under operational settings is a more complex and often time-consuming process that is rarely addressed directly in research settings. We present a set of new heuristic approaches to train and deploy an operational solar flare prediction system for ≥M1.0-class flares with two prediction modes: full-disk and active region-based. In full-disk mode, predictions are performed on full-disk line-of-sight magnetograms using deep learning models, whereas in active region-based mode, predictions are issued for each active region individually using multivariate time series data instances. The outputs from individual active region forecasts and full-disk predictors are combined into a final full-disk prediction result with a meta-model. We utilized an equal-weighted average ensemble of two base learners' flare probabilities as our baseline meta learner and improved the capabilities of our two base learners by training a logistic regression model.
The major findings of this study are: 1) We successfully coupled two heterogeneous flare prediction models trained with different datasets and model architectures to predict a full-disk flare probability for the next 24 h, 2) Our proposed ensembling model, i.e., logistic regression, improves on the predictive performance of the two base learners and the baseline meta learner, measured in terms of two widely used metrics, True Skill Statistic (TSS) and Heidke Skill Score (HSS), and 3) Our result analysis suggests that the logistic regression-based ensemble (Meta-FP) improves on the full-disk model (base learner) by ∼9% in terms of TSS and ∼10% in terms of HSS. Similarly, it improves on the AR-based model (base learner) by ∼17% and ∼20% in terms of TSS and HSS, respectively. Finally, when compared to the baseline meta model, it improves on TSS by ∼10% and HSS by ∼15%.
  4. Federated learning (FL) involves training a model over massive distributed devices, while keeping the training data localized and private. This form of collaborative learning exposes new tradeoffs among model convergence speed, model accuracy, balance across clients, and communication cost, with new challenges including: (1) straggler problem—where clients lag due to data or (computing and network) resource heterogeneity, and (2) communication bottleneck—where a large number of clients communicate their local updates to a central server and bottleneck the server. Many existing FL methods focus on optimizing along only one single dimension of the tradeoff space. Existing solutions use asynchronous model updating or tiering-based, synchronous mechanisms to tackle the straggler problem. However, asynchronous methods can easily create a communication bottleneck, while tiering may introduce biases that favor faster tiers with shorter response latencies. To address these issues, we present FedAT, a novel Federated learning system with Asynchronous Tiers under Non-i.i.d. training data. FedAT synergistically combines synchronous, intra-tier training and asynchronous, cross-tier training. By bridging the synchronous and asynchronous training through tiering, FedAT minimizes the straggler effect with improved convergence speed and test accuracy. FedAT uses a straggler-aware, weighted aggregation heuristic to steer and balance the training across clients for further accuracy improvement. FedAT compresses uplink and downlink communications using an efficient, polyline-encoding-based compression algorithm, which minimizes the communication cost. Results show that FedAT improves the prediction performance by up to 21.09% and reduces the communication cost by up to 8.5×, compared to state-of-the-art FL methods. 
  5. Federated Learning (FL) enables multiple clients to collaboratively learn a machine learning model without exchanging their own local data. In this way, the server can exploit the computational power of all clients and train the model on a larger set of data samples among all clients. Although such a mechanism is proven to be effective in various fields, existing works generally assume that each client preserves sufficient data for training. In practice, however, certain clients may only contain a limited number of samples (i.e., few-shot samples). For example, the available photo data taken by a specific user with a new mobile device is relatively rare. In this scenario, existing FL efforts typically encounter a significant performance drop on these clients. Therefore, it is urgent to develop a few-shot model that can generalize to clients with limited data under the FL scenario. In this paper, we refer to this novel problem as federated few-shot learning. Nevertheless, the problem remains challenging due to two major reasons: the global data variance among clients (i.e., the difference in data distributions among clients) and the local data insufficiency in each client (i.e., the lack of adequate local data for training). To overcome these two challenges, we propose a novel federated few-shot learning framework with two separately updated models and dedicated training strategies to reduce the adverse impact of global data variance and local data insufficiency. Extensive experiments on four prevalent datasets that cover news articles and images validate the effectiveness of our framework compared with the state-of-the-art baselines. 
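The solar flare entry in the list above reports its results in terms of the True Skill Statistic (TSS) and the Heidke Skill Score (HSS). For readers unfamiliar with these metrics, the sketch below computes both from a binary confusion matrix using their standard definitions. The function names are hypothetical, and the assumption that the study uses this particular two-class form of HSS (common in flare forecasting) is not confirmed by the abstract.

```python
def true_skill_statistic(tp, fp, fn, tn):
    """TSS = recall - false alarm rate.

    tp, fp, fn, tn: counts of true positives, false positives,
    false negatives, and true negatives. Ranges from -1 to 1;
    0 means no skill.
    """
    return tp / (tp + fn) - fp / (fp + tn)


def heidke_skill_score(tp, fp, fn, tn):
    """Two-class HSS: improvement over a random (chance) forecast.

    Assumed form: 2(TP*TN - FN*FP) /
                  ((TP+FN)(FN+TN) + (TP+FP)(FP+TN)).
    1 is a perfect forecast; 0 means no better than chance.
    """
    numerator = 2 * (tp * tn - fn * fp)
    denominator = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    return numerator / denominator


if __name__ == "__main__":
    # Hypothetical counts for a rare-event forecast, for illustration only.
    tp, fp, fn, tn = 8, 10, 2, 80
    print("TSS:", true_skill_statistic(tp, fp, fn, tn))
    print("HSS:", heidke_skill_score(tp, fp, fn, tn))
```

TSS depends only on the hit rate and the false alarm rate, so it is insensitive to the class imbalance ratio, which is one reason it is widely used for rare events such as strong flares.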