Title: Masked Completion via Structured Diffusion with White-Box Transformers
Modern learning frameworks often train deep neural networks with massive amounts of unlabeled data to learn representations by solving simple pretext tasks, then use the representations as foundations for downstream tasks. These networks are empirically designed; as such, they are usually not interpretable, their representations are not structured, and their designs are potentially redundant. White-box deep networks, in which each layer explicitly identifies and transforms structures in the data, present a promising alternative. However, existing white-box architectures have only been shown to work at scale in supervised settings with labeled data, such as classification. In this work, we provide the first instantiation of the white-box design paradigm that can be applied to large-scale unsupervised representation learning. We do this by exploiting a fundamental connection between diffusion, compression, and (masked) completion, deriving a deep transformer-like masked autoencoder architecture, called CRATE-MAE, in which the role of each layer is mathematically fully interpretable: they transform the data distribution to and from a structured representation. Extensive empirical evaluations confirm our analytical insights. CRATE-MAE demonstrates highly promising performance on large-scale imagery datasets while using only ~30% of the parameters compared to the standard masked autoencoder with the same model configuration. The representations learned by CRATE-MAE have explicit structure and also contain semantic meaning. Code is available at https://github.com/Ma-Lab-Berkeley/CRATE
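As a rough illustration of the masked completion pretext task mentioned in the abstract, the sketch below hides a random subset of patch indices and scores a reconstruction on the hidden patches only. It is not taken from the CRATE repository: the function names `random_patch_mask` and `masked_mse`, the 75% mask ratio, the 196-patch grid, and the choice to compute the loss on masked patches only are all assumptions made for illustration.

```python
import random


def random_patch_mask(num_patches, mask_ratio=0.75, seed=None):
    """Split patch indices into (visible, masked) lists, both sorted.

    mask_ratio=0.75 is an assumed default, not a value from the abstract.
    """
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    rng = random.Random(seed)
    num_masked = int(num_patches * mask_ratio)
    # Sample without replacement so each patch is masked at most once.
    masked = set(rng.sample(range(num_patches), num_masked))
    visible = [i for i in range(num_patches) if i not in masked]
    return visible, sorted(masked)


def masked_mse(original, reconstructed, masked):
    """Mean squared error over the masked patches only (an assumed loss).

    `original` and `reconstructed` are lists of patches, each patch being
    a list of floats of equal length.
    """
    total, count = 0.0, 0
    for i in masked:
        for a, b in zip(original[i], reconstructed[i]):
            total += (a - b) ** 2
            count += 1
    return total / count if count else 0.0


# Usage: a hypothetical 14 x 14 grid of patches (196 in total).
visible, masked = random_patch_mask(196, mask_ratio=0.75, seed=0)
```

In this setup an encoder would see only the `visible` patches, and a decoder would be trained to fill in the `masked` ones. How CRATE-MAE itself selects and scores patches is not stated in the abstract.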
Award ID(s):
2031899
PAR ID:
10540485
Publisher / Repository:
International Conference on Learning Representations
Format(s):
Medium: X
Sponsoring Org:
National Science Foundation
More Like this
  1. In this paper, we contend that the objective of representation learning is to compress and transform the distribution of the data, say sets of tokens, towards a mixture of low-dimensional Gaussian distributions supported on incoherent subspaces. The quality of the final representation can be measured by a unified objective function called sparse rate reduction. From this perspective, popular deep networks such as transformers can be naturally viewed as realizing iterative schemes to optimize this objective incrementally. Particularly, we show that the standard transformer block can be derived from alternating optimization on complementary parts of this objective: the multi-head self-attention operator can be viewed as a gradient descent step to compress the token sets by minimizing their lossy coding rate, and the subsequent multi-layer perceptron can be viewed as attempting to sparsify the representation of the tokens. This leads to a family of white-box transformer-like deep network architectures which are mathematically fully interpretable. Despite their simplicity, experiments show that these networks indeed learn to optimize the designed objective: they compress and sparsify representations of large-scale real-world vision datasets such as ImageNet, and achieve performance very close to thoroughly engineered transformers such as ViT. Code is at https://github.com/Ma-Lab-Berkeley/CRATE.
  2. Self-supervised learning with masked autoencoders has recently gained popularity for its ability to produce effective image or textual representations, which can be applied to various downstream tasks without retraining. However, we observe that the current masked autoencoder models lack good generalization ability on graph data. To tackle this issue, we propose a novel graph masked autoencoder framework called GiGaMAE. Different from existing masked autoencoders that learn node representations by explicitly reconstructing the original graph components (e.g., features or edges), in this paper, we propose to collaboratively reconstruct informative and integrated latent embeddings. By considering embeddings encompassing graph topology and attribute information as reconstruction targets, our model could capture more generalized and comprehensive knowledge. Furthermore, we introduce a mutual information based reconstruction loss that enables the effective reconstruction of multiple targets. This learning objective allows us to differentiate between the exclusive knowledge learned from a single target and common knowledge shared by multiple targets. We evaluate our method on three downstream tasks with seven datasets as benchmarks. Extensive experiments demonstrate the superiority of GiGaMAE against state-of-the-art baselines. We hope our results will shed light on the design of foundation models on graph-structured data. Our code is available at: https://github.com/sycny/GiGaMAE.
  3. We propose split-brain autoencoders, a straightforward modification of the traditional autoencoder architecture, for unsupervised representation learning. The method adds a split to the network, resulting in two disjoint sub-networks. Each sub-network is trained to perform a difficult task -- predicting one subset of the data channels from another. Together, the sub-networks extract features from the entire input signal. By forcing the network to solve cross-channel prediction tasks, we induce a representation within the network which transfers well to other, unseen tasks. This method achieves state-of-the-art performance on several large-scale transfer learning benchmarks. 
  4. Deep reinforcement learning (RL) has led to encouraging successes in many challenging control tasks. However, a deep RL model lacks interpretability due to the difficulty of identifying how the model's control logic relates to its network structure. Programmatic policies structured in more interpretable representations emerge as a promising solution. Yet two shortcomings remain: First, synthesizing programmatic policies requires optimizing over the discrete and non-differentiable search space of program architectures. Previous works are suboptimal because they only enumerate program architectures greedily guided by a pretrained RL oracle. Second, these works do not exploit compositionality, an important programming concept, to reuse and compose primitive functions to form a complex function for new tasks. Our first contribution is a programmatically interpretable RL framework that conducts program architecture search on top of a continuous relaxation of the architecture space defined by programming language grammar rules. Our algorithm allows policy architectures to be learned with policy parameters via bilevel optimization using efficient policy-gradient methods, and thus does not require a pretrained oracle. Our second contribution is improving programmatic policies to support compositionality by integrating primitive functions learned to grasp task-agnostic skills as a composite program to solve novel RL problems. Experiment results demonstrate that our algorithm excels in discovering optimal programmatic policies that are highly interpretable. 
  5. The development and application of deep learning methodologies have grown within educational contexts in recent years. Perhaps attributable, in part, to the large amount of data that is made available through the adoption of computer-based learning systems in classrooms and larger-scale MOOC platforms, many educational researchers are leveraging a wide range of emerging deep learning approaches to study learning and student behavior in various capacities. Variations of recurrent neural networks, for example, have been used not only to predict learning outcomes but also to study sequential and temporal trends in student data; it is commonly believed that they are able to learn high-dimensional representations of learning and behavioral constructs over time, such as the evolution of a student’s knowledge state while working through assigned content. Recent works, however, have started to dispute this belief, instead finding that it may be the model’s complexity that leads to improved performance in many prediction tasks and that these methods may not inherently learn these temporal representations through model training. In this work, we explore these claims further in the context of detectors of student affect as well as expanding on existing work that explored benchmarks in knowledge tracing. Specifically, we observe how well trained models perform compared to deep learning networks where training is applied only to the output layer. While the highest results of prior works utilizing trained recurrent models are found to be superior, our untrained versions perform comparably well, outperforming even previous non-deep learning approaches.