Title: Causal Markov Blanket Representation Learning for Out-of-distribution Generalization
The pursuit of generalizable representations is an active area of research in machine learning and computer vision. Current methods typically obtain invariant representations by either harnessing domain expertise or leveraging data from multiple domains. In this paper, we introduce a novel approach that learns Causal Markov Blanket (CMB) representations to improve prediction performance under distribution shifts. CMB representations comprise the direct causes and effects of the target variable, which renders them invariant across diverse domains. Specifically, we first introduce a novel structural causal model (SCM) with latent representations, designed to capture the causal mechanisms that govern the data-generation process. We then propose a CMB representation learning framework that derives representations conforming to this SCM. Compared with state-of-the-art domain generalization methods, our approach exhibits robustness and adaptability under distribution shifts.
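For readers less familiar with the graphical notion behind CMB representations: in the standard definition, the Markov blanket of a variable Y in a directed acyclic graph consists of its parents, its children, and the other parents of its children, and conditioning on it makes Y independent of all remaining variables. The abstract describes CMB representations in terms of direct causes and effects, so the minimal sketch below computes both sets on a toy graph for comparison. The graph, the variable names, and the helper functions are hypothetical illustrations; they operate on observed nodes, not on the learned latent representations or the SCM proposed in the paper.

```python
from typing import Dict, Set

# Hypothetical DAG encoding: node -> set of its parents (direct causes).
ParentMap = Dict[str, Set[str]]


def children_of(parents: ParentMap, target: str) -> Set[str]:
    """Nodes that list `target` among their parents (direct effects)."""
    return {node for node, pa in parents.items() if target in pa}


def direct_causes_and_effects(parents: ParentMap, target: str) -> Set[str]:
    """Parents and children of `target`."""
    return set(parents.get(target, set())) | children_of(parents, target)


def markov_blanket(parents: ParentMap, target: str) -> Set[str]:
    """Standard DAG Markov blanket: parents, children, and children's other parents."""
    blanket = direct_causes_and_effects(parents, target)
    for child in children_of(parents, target):
        blanket |= parents[child]  # add co-parents ("spouses") of the target
    blanket.discard(target)
    return blanket


# Hypothetical toy SCM: N -> C -> Y -> E <- S, and E -> Z.
toy_graph: ParentMap = {
    "N": set(),
    "C": {"N"},
    "Y": {"C"},
    "S": set(),
    "E": {"Y", "S"},
    "Z": {"E"},
}

print(sorted(direct_causes_and_effects(toy_graph, "Y")))  # ['C', 'E']
print(sorted(markov_blanket(toy_graph, "Y")))             # ['C', 'E', 'S']
```

The grandparent N and the grandchild Z fall outside both sets, while the co-parent S belongs to the standard Markov blanket but not to the set of direct causes and effects.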
Award ID(s): 2236026
PAR ID: 10513327
Author(s) / Creator(s): ; ; ; ;
Publisher / Repository: NeurIPS 2024
Date Published:
Journal Name: NeurIPS Workshop: Causal Representation Learning, 2023
Format(s): Medium: X
Sponsoring Org: National Science Foundation
More Like This
  1. In general, graph representation learning methods assume that the train and test data come from the same distribution. In this work, we consider an underexplored area of the otherwise rapidly developing field of graph representation learning: out-of-distribution (OOD) graph classification, where the train and test data have different distributions and the test data are unavailable during training. Our work shows that a causal model can be used to learn approximately invariant representations that extrapolate better between train and test data. Finally, we present experiments on synthetic and real-world datasets that show the benefits of representations that are invariant to train/test distribution shifts.
  2. We study the problem of causal structure learning in linear systems from observational data given in multiple domains, across which the causal coefficients and/or the distributions of the exogenous noises may vary. The main tool in our approach is the principle that, in a causally sufficient system, the causal modules, as well as the parameters they contain, change independently across domains. We first introduce our approach for finding the causal direction in a system of two variables and propose efficient methods for identifying it. We then generalize these methods to causal structure learning in networks of variables. Most previous work on structure learning from multi-domain data assumes that certain types of invariance hold in the causal modules across domains. Our approach unifies the ideas in those works and generalizes to the case where no such invariance holds across domains. Our proposed methods are generally capable of identifying the causal direction from fewer than ten domains; when the invariance property holds, two domains are generally sufficient.
  3. The goal of domain adaptation is to train a high-performance predictive model on target-domain data by using knowledge from source-domain data, which follow a different but related distribution. In this paper, we consider unsupervised domain adaptation, where the source-domain data are labelled but the target-domain data are not. Our solution is to learn a domain-invariant representation that is also category-discriminative. Domain invariance is achieved by minimizing the discrepancy between the domains, for which we propose a novel graph-matching metric between the source and target domain representations. Minimizing this metric allows the source and target representations to support each other. We further exploit confident unlabelled target-domain samples and their pseudo-labels to refine the model, which we expect to improve performance further. We validate this with experiments on standard image classification adaptation datasets, and the results show that our approach outperforms previous domain-invariant representation learning approaches.
  4. Adversarial learning has performed well in unsupervised domain adaptation by learning domain-invariant representations. However, recent work has shown limitations of this approach when the label distributions differ between the source and target domains. In this paper, we propose a new assumption, generalized label shift (GLS), to improve robustness against mismatched label distributions. GLS states that, conditioned on the label, there exists a representation of the input that is invariant between the source and target domains. Under GLS, we provide theoretical guarantees on the transfer performance of any classifier. We also devise necessary and sufficient conditions for GLS to hold, by using an estimate of the relative class weights between domains and an appropriate reweighting of samples. Our weight estimation method can be applied generically to existing domain adaptation (DA) algorithms that learn domain-invariant representations, with small computational overhead. In particular, we modify three DA algorithms, JAN, DANN and CDAN, and evaluate their performance on standard and artificial DA tasks. Our algorithms outperform the base versions, with large improvements when the label distributions are strongly mismatched. Our code is available at https://tinyurl.com/y585xt6j.
  5. Classification models trained on data from one source may underperform when tested on data acquired from different sources because of shifts in the data distributions, which limits the models' generalizability in real-world applications. Existing domain adaptation methods that align such shifts between source and target distributions use contrastive learning or adversarial techniques, with or without internal cluster alignment. When used, intra-cluster alignment is performed with standalone k-means clustering on image embeddings. This paper introduces a novel deep clustering approach that aligns cluster distributions while also adapting the source and target data distributions. Our method learns a mixture of cluster distributions in the unlabeled target domain and aligns it with that of the source domain in a unified deep representation learning framework. Experiments demonstrate that intra-cluster alignment improves classification accuracy in nine out of ten domain adaptation examples. These improvements range from 0.3% to 2.0% compared with k-means clustering of embeddings and from 0.4% to 5.8% compared with methods without class-level alignment. Unlike current domain adaptation methods, the proposed cluster-distribution-based deep learning approach also provides a quantitative and explainable measure of distribution shifts between data domains. The source code for the algorithm implementation is publicly available.
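Item 4 in the list above mentions estimating relative class weights w_y = p_T(y) / p_S(y) between domains and reweighting source samples accordingly. The sketch below illustrates that idea under simplifying assumptions. If predictions are computed from a representation whose class-conditional distribution is the same in both domains, then the target prediction distribution μ satisfies μ = C w, where C[i][j] is the source joint frequency of (prediction i, label j). The sketch solves this system directly, in the style of black-box shift estimation, and clips negative entries. The GLS paper's own estimator may instead solve a constrained problem (for example with non-negativity and normalization constraints), so this is not the authors' method. All function names and the toy data are hypothetical.

```python
from typing import List, Sequence


def solve_linear(a: List[List[float]], b: List[float]) -> List[float]:
    """Solve a @ x = b with Gaussian elimination and partial pivoting."""
    n = len(b)
    m = [list(row) + [b[i]] for i, row in enumerate(a)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if abs(m[pivot][col]) < 1e-12:
            raise ValueError("confusion matrix is singular")
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= factor * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):  # back substitution
        tail = sum(m[r][c] * x[c] for c in range(r + 1, n))
        x[r] = (m[r][n] - tail) / m[r][r]
    return x


def estimate_class_weights(src_preds: Sequence[int], src_labels: Sequence[int],
                           tgt_preds: Sequence[int], num_classes: int) -> List[float]:
    """Estimate w[y] ~ p_T(y) / p_S(y) from C w = mu (simplified, unconstrained)."""
    conf = [[0.0] * num_classes for _ in range(num_classes)]
    for pred, label in zip(src_preds, src_labels):
        conf[pred][label] += 1.0 / len(src_labels)  # joint p_S(Yhat=pred, Y=label)
    mu = [0.0] * num_classes
    for pred in tgt_preds:
        mu[pred] += 1.0 / len(tgt_preds)  # p_T(Yhat=pred)
    weights = solve_linear(conf, mu)
    return [max(w, 0.0) for w in weights]  # crude stand-in for a w >= 0 constraint


def reweighted_source_loss(losses: Sequence[float], labels: Sequence[int],
                           weights: Sequence[float]) -> float:
    """Mean per-sample source loss, each sample scaled by the weight of its label."""
    return sum(weights[y] * loss for loss, y in zip(losses, labels)) / len(losses)


# Toy data: balanced source labels, a 75%-accurate classifier, and target
# predictions consistent with a target class prior of (0.75, 0.25).
src_labels = [0, 0, 0, 0, 1, 1, 1, 1]
src_preds = [0, 0, 0, 1, 1, 1, 1, 0]
tgt_preds = [0] * 5 + [1] * 3

w = estimate_class_weights(src_preds, src_labels, tgt_preds, num_classes=2)
print([round(v, 6) for v in w])  # [1.5, 0.5]
print(round(reweighted_source_loss([0.2] * 4 + [1.0] * 4, src_labels, w), 6))  # 0.4
```

Because the source prior is (0.5, 0.5), the recovered weights (1.5, 0.5) correspond to a target prior of (0.75, 0.25), and the reweighted loss emphasizes the class that is more frequent in the target domain.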