

Title: Interpreting and Improving Deep-Learning Models with Reality Checks
Recent deep-learning models have achieved impressive predictive performance by learning complex functions of many variables, often at the cost of interpretability. This chapter covers recent work aiming to interpret models by attributing importance to features and feature groups for a single prediction. Importantly, the proposed attributions assign importance to interactions between features, in addition to features in isolation. These attributions are shown to yield insights across real-world domains, including bio-imaging, cosmology imaging, and natural-language processing. We then show how these attributions can be used to directly improve the generalization of a neural network or to distill it into a simple model. Throughout the chapter, we emphasize the use of reality checks to scrutinize the proposed interpretation techniques. (Code for all methods in this chapter is available at github.com/csinva and github.com/Yu-Group, implemented in PyTorch [54].)
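The idea of scoring a feature group for one prediction, and separating the group's interaction from the contributions of its members, can be illustrated with a minimal sketch. It does not reproduce the chapter's own attribution methods: it swaps in simple occlusion (features outside the group are replaced by a baseline value), it is written in plain Python rather than PyTorch so that it runs on its own, and the names `occlude`, `group_importance`, `interaction` and `toy_model` are hypothetical.

```python
from itertools import combinations
from typing import Callable, Iterable, List, Sequence

# Hypothetical model type: maps one input vector to a scalar prediction.
Model = Callable[[Sequence[float]], float]


def occlude(x: Sequence[float], keep: set, baseline: Sequence[float]) -> List[float]:
    """Keep the features whose index is in `keep`; replace all others by the baseline."""
    return [xi if i in keep else bi for i, (xi, bi) in enumerate(zip(x, baseline))]


def group_importance(f: Model, x: Sequence[float], group: Iterable[int],
                     baseline: Sequence[float]) -> float:
    """Change in prediction when only `group` is revealed, relative to the all-baseline input."""
    return f(occlude(x, set(group), baseline)) - f(list(baseline))


def interaction(f: Model, x: Sequence[float], group: Sequence[int],
                baseline: Sequence[float]) -> float:
    """Group importance minus the sum of the members' individual importances."""
    joint = group_importance(f, x, group, baseline)
    return joint - sum(group_importance(f, x, [i], baseline) for i in group)


# Hypothetical toy model with a multiplicative interaction between features 0 and 1.
def toy_model(z: Sequence[float]) -> float:
    return z[0] * z[1] + z[2]


x, base = [2.0, 3.0, 1.0], [0.0, 0.0, 0.0]
for pair in combinations(range(len(x)), 2):
    print(pair, interaction(toy_model, x, pair, base))
# (0, 1) 6.0
# (0, 2) 0.0
# (1, 2) 0.0
```

The additive feature pairs receive zero interaction, while the product term appears only at the group level. Because the toy model's ground truth is known, a check like this is one simple way to sanity-test an attribution method before trusting it on a real network.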
Award ID(s):
2023505
NSF-PAR ID:
10343666
Journal Name:
Lecture Notes in Computer Science
ISSN:
0302-9743
Format(s):
Medium: X
Sponsoring Org:
National Science Foundation
More Like this
  1. Recent deep-learning models have achieved impressive predictive performance, but often sacrifice interpretability and computational efficiency. Interpretability is crucial in many disciplines, such as science and medicine, where models must be carefully vetted or where interpretation is itself the goal. Moreover, interpretable models are concise and often computationally efficient. Here, we propose adaptive wavelet distillation (AWD), a method that aims to distill information from a trained neural network into a wavelet transform. Specifically, AWD penalizes feature attributions of a neural network in the wavelet domain to learn an effective multi-resolution wavelet transform. The resulting model is highly predictive, concise, computationally efficient, and has properties (such as a multi-scale structure) that make it easy to interpret. In close collaboration with domain experts, we showcase how AWD addresses challenges in two real-world settings: cosmological parameter inference and molecular-partner prediction. In both cases, AWD yields a scientifically interpretable and concise model that outperforms state-of-the-art neural networks. Moreover, AWD identifies predictive features that are scientifically meaningful in the context of their respective domains. All code and models are released in a full-fledged package available on GitHub.
  2. In recent years, methods have been proposed for assigning feature importance scores that measure the contribution of individual features. While in some cases the goal is to understand a specific model, in many cases the goal is to understand the contribution of certain properties (features) to a real-world phenomenon. Thus, a distinction has been made between feature importance scores that explain a model and scores that explain the data. When explaining the data, machine learning models are used as proxies in settings where conducting many real-world experiments is expensive or prohibited. While existing feature importance scores have been very successful at explaining models, we demonstrate their limitations when explaining the data, especially in the presence of correlations between features. We therefore develop a set of axioms capturing the properties expected from a feature importance score that explains data, and prove that only one score satisfies all of them: the Marginal Contribution Feature Importance (MCI). We analyze the theoretical properties of this score function and demonstrate its merits empirically.
  3. One-parameter persistent homology, a cornerstone of Topological Data Analysis (TDA), studies the evolution of topological features such as connected components and cycles hidden in data. It has been applied to enhance the representation power of deep learning models, such as Graph Neural Networks (GNNs). To enrich the representations of topological features, we propose to study 2-parameter persistence modules induced by bi-filtration functions. To incorporate these representations into machine learning models, we introduce a novel vector representation called Generalized Rank Invariant Landscape (GRIL) for 2-parameter persistence modules. We show that this vector representation is 1-Lipschitz stable and differentiable with respect to the underlying filtration functions, and that it can be easily integrated into machine learning models to augment them with encoded topological features. We present an algorithm to compute the vector representation efficiently. We also test our methods on synthetic and benchmark graph datasets and compare the results with previous vector representations of 1-parameter and 2-parameter persistence modules. Further, we augment GNNs with GRIL features and observe improved performance, indicating that GRIL can capture additional features that enrich GNNs. We make the complete code for the proposed method available at https://github.com/soham0209/mpml-graph.
  4. Kuijjer, Marieke (Ed.)
    Motivation: Biological processes are regulated by underlying genes and their interactions, which form gene regulatory networks (GRNs). Dysregulation of these GRNs can cause complex diseases such as cancer, Alzheimer’s and diabetes. Hence, accurate GRN inference is critical for elucidating gene function, allowing faster identification and prioritization of candidate genes for functional investigation. Several statistical and machine learning-based methods have been developed to infer GRNs from biological and synthetic datasets. Here, we developed a method named AGRN that infers GRNs by employing an ensemble of machine learning algorithms. Results: Starting from the idea that a single method may not perform well on all datasets, we calculate gene importance scores using three machine learning methods: random forest, extra-trees and support vector regressors. We calculate the importance scores with Shapley Additive Explanations (SHAP), a recently published method for explaining machine learning models. We found that importance scores from Shapley values perform better than traditional importance scoring methods on almost all of the benchmark datasets. We analyzed the performance of AGRN using the datasets from the DREAM4 and DREAM5 challenges for GRN inference. The proposed method, AGRN (an ensemble machine learning method with Shapley values), outperforms existing methods on both the DREAM4 and DREAM5 datasets. Given this improved accuracy, we believe that AGRN-inferred GRNs could enhance our mechanistic understanding of biological processes in health and disease. Availability and implementation: https://github.com/DuaaAlawad/AGRN. Supplementary information: Supplementary data are available at Bioinformatics online.
  5. Graph Neural Networks have recently become a prevailing paradigm for various high-impact graph analytical problems. Existing efforts can be mainly categorized as spectral-based and spatial-based methods. The major challenge for the former is to find an appropriate graph filter to distill discriminative information from input signals for learning. Recently, many efforts have been made to design better graph filters, e.g., the Graph Convolutional Network (GCN), which leverages a truncated Chebyshev polynomial approximation of graph filters and bridges these two families of methods. Nevertheless, recent studies have shown that GCN and its variants essentially employ fixed low-pass filters to perform information denoising. Thus, their learning capability is rather limited, and they may over-smooth node representations at deeper layers. To tackle these problems, we develop a novel graph neural network framework, AdaGNN, with a well-designed adaptive frequency response filter. At its core, AdaGNN leverages a simple but elegant trainable filter that spans multiple layers to capture the varying importance of different frequency components for node representation learning. The filter also captures the inherent differences among feature channels. As such, it gives AdaGNN stronger expressiveness and naturally alleviates the over-smoothing problem. We empirically validate the effectiveness of the proposed framework on various benchmark datasets. Theoretical analysis is also provided to show the superiority of the proposed AdaGNN. The open-source implementation of AdaGNN is available at https://github.com/yushundong/AdaGNN.
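Item 1 in the list above describes learning a multi-resolution wavelet transform. As a minimal sketch of the multi-resolution structure only, the code below implements the fixed orthonormal Haar wavelet; AWD instead learns its wavelet filters by penalizing attributions of a trained network, which this sketch does not attempt. All function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

INV_SQRT2 = 1.0 / math.sqrt(2.0)


def haar_step(signal: Sequence[float]) -> Tuple[List[float], List[float]]:
    """One level of the orthonormal Haar transform: (approximation, detail) coefficients."""
    if len(signal) % 2:
        raise ValueError("signal length must be even at every level")
    pairs = list(zip(signal[0::2], signal[1::2]))
    approx = [INV_SQRT2 * (a + b) for a, b in pairs]  # scaled local sums (low-pass)
    detail = [INV_SQRT2 * (a - b) for a, b in pairs]  # scaled local differences (high-pass)
    return approx, detail


def inverse_haar_step(approx: Sequence[float], detail: Sequence[float]) -> List[float]:
    """Exactly undo haar_step."""
    out: List[float] = []
    for a, d in zip(approx, detail):
        out.extend([INV_SQRT2 * (a + d), INV_SQRT2 * (a - d)])
    return out


def haar_decompose(signal: Sequence[float], levels: int) -> List[List[float]]:
    """Return [coarsest approximation, coarsest detail, ..., finest detail]."""
    approx: List[float] = list(signal)
    details: List[List[float]] = []
    for _ in range(levels):
        approx, d = haar_step(approx)
        details.append(d)  # the finest scale is appended first
    return [approx] + details[::-1]


def haar_reconstruct(coeffs: Sequence[Sequence[float]]) -> List[float]:
    """Invert haar_decompose, working from the coarsest scale to the finest."""
    approx = list(coeffs[0])
    for detail in coeffs[1:]:
        approx = inverse_haar_step(approx, detail)
    return approx


coeffs = haar_decompose([4.0, 2.0, 5.0, 5.0], levels=2)
print(coeffs)
print(haar_reconstruct(coeffs))
```

Because the transform is orthonormal, it preserves the signal's energy and is exactly invertible: the coarse coefficients summarize the signal at low resolution, while each detail band records changes at one scale, which is the kind of multi-scale structure the abstract credits for interpretability.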
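The MCI score in item 2 above is, as we understand its definition, the largest marginal gain a feature adds to any subset of the other features: MCI(f) = max over S ⊆ F \ {f} of v(S ∪ {f}) - v(S), where v(S) is an evaluation function such as the predictive performance achievable from the features in S. The brute-force sketch below assumes v is supplied as a Python function; the names `mci` and `redundant_value` are hypothetical, and the enumeration is exponential in the number of features.

```python
from itertools import combinations
from typing import Callable, Dict, FrozenSet, Hashable, Sequence


def mci(features: Sequence[Hashable],
        v: Callable[[FrozenSet[Hashable]], float]) -> Dict[Hashable, float]:
    """Brute-force marginal contribution: max over subsets S without f of v(S | {f}) - v(S)."""
    scores: Dict[Hashable, float] = {}
    for f in features:
        rest = [g for g in features if g != f]
        best = float("-inf")
        for r in range(len(rest) + 1):
            for subset in combinations(rest, r):
                s = frozenset(subset)
                best = max(best, v(s | {f}) - v(s))
        scores[f] = best
    return scores


def redundant_value(s: FrozenSet[str]) -> float:
    """Hypothetical evaluation: 'a' and 'b' carry the same signal, 'c' is irrelevant."""
    return 1.0 if s & {"a", "b"} else 0.0


print(mci(["a", "b", "c"], redundant_value))
# {'a': 1.0, 'b': 1.0, 'c': 0.0}
```

The toy evaluation mimics two perfectly correlated features. MCI gives each of them full credit, whereas Shapley values computed from the same v would split the credit 0.5 and 0.5, which illustrates the correlation issue the abstract raises for explaining data.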
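Item 3 above builds on 1-parameter persistent homology. As a minimal sketch of that building block only (not of GRIL or 2-parameter modules), the code below computes 0-dimensional persistence pairs, i.e. the birth and death values of connected components, for a graph whose vertices carry filtration values. It uses union-find with the elder rule (when two components merge, the younger one dies). The lower-star convention, in which an edge enters at the larger of its endpoint values, and the function name are assumptions of this sketch.

```python
import math
from typing import List, Sequence, Tuple


def zero_dim_persistence(values: Sequence[float],
                         edges: Sequence[Tuple[int, int]]) -> List[Tuple[float, float]]:
    """0-dimensional (birth, death) pairs of a vertex-filtered graph; survivors die at inf."""
    n = len(values)
    parent = list(range(n))
    birth = list(values)  # birth value of the component rooted at each vertex

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving keeps the trees shallow
            i = parent[i]
        return i

    pairs: List[Tuple[float, float]] = []
    # An edge enters once both endpoints are present, at the larger endpoint value.
    for u, w in sorted(edges, key=lambda e: max(values[e[0]], values[e[1]])):
        t = max(values[u], values[w])
        ru, rw = find(u), find(w)
        if ru == rw:
            continue  # the edge closes a cycle: no 0-dimensional event
        if birth[ru] > birth[rw]:
            ru, rw = rw, ru  # elder rule: ru is the older component and survives
        if birth[rw] < t:
            pairs.append((birth[rw], t))  # skip zero-length pairs on the diagonal
        parent[rw] = ru
    roots = {find(i) for i in range(n)}
    pairs.extend((birth[r], math.inf) for r in roots)
    return sorted(pairs)


# Path graph 0-1-2-3-4 with local minima at vertices 0, 2 and 4.
print(zero_dim_persistence([0.0, 3.0, 1.0, 4.0, 2.0], [(0, 1), (1, 2), (2, 3), (3, 4)]))
# [(0.0, inf), (1.0, 3.0), (2.0, 4.0)]
```

Each finite pair records a local minimum whose component is absorbed by an older one; the global minimum never dies. A 2-parameter module, as studied in item 3, replaces the single filtration value per vertex with a pair of values.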
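Item 4 above scores candidate regulators with Shapley values. The sketch below computes exact Shapley values for a single prediction by enumerating feature subsets, filling absent features with a fixed baseline (one of several possible conventions). It then averages absolute values over samples to obtain a global score; that aggregation is a common choice but an assumption here. This is not AGRN's pipeline, which applies SHAP to random forest, extra-trees and support vector regressors. The function names and toy model are hypothetical, and exact enumeration is only practical for a handful of features.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence

Model = Callable[[Sequence[float]], float]


def shapley_values(f: Model, x: Sequence[float], baseline: Sequence[float]) -> List[float]:
    """Exact Shapley values of f at x; features outside a coalition take their baseline value."""
    n = len(x)

    def value(present: set) -> float:
        return f([x[i] if i in present else baseline[i] for i in range(n)])

    phi: List[float] = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for r in range(n):  # r = size of the coalition S, which excludes i
            weight = factorial(r) * factorial(n - r - 1) / factorial(n)
            for s in combinations(others, r):
                coalition = set(s)
                total += weight * (value(coalition | {i}) - value(coalition))
        phi.append(total)
    return phi


def mean_abs_shapley(f: Model, X: Sequence[Sequence[float]],
                     baseline: Sequence[float]) -> List[float]:
    """Hypothetical global score: mean |Shapley value| of each feature over the samples in X."""
    rows = [shapley_values(f, x, baseline) for x in X]
    return [sum(abs(row[i]) for row in rows) / len(rows) for i in range(len(baseline))]


def toy_target(z: Sequence[float]) -> float:
    # Hypothetical target gene: regulators 0 and 1 act jointly, regulator 2 additively.
    return z[0] * z[1] + z[2]


phi = shapley_values(toy_target, [2.0, 3.0, 1.0], [0.0, 0.0, 0.0])
print([round(p, 6) for p in phi])  # [3.0, 3.0, 1.0]
```

The interaction term (6.0) is split equally between the two jointly acting regulators, and the values sum to f(x) - f(baseline) = 7.0, the efficiency property that makes Shapley values attractive as importance scores.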
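Item 5 above describes a trainable filter that weights frequency components differently for each feature channel. One way to write such a layer, which we assume here rather than take from the abstract, is X' = X - L X diag(phi), where L = I - D^(-1/2) (A + I) D^(-1/2) is a normalized graph Laplacian with self-loops (D being the degree matrix of A + I) and phi holds one learnable coefficient per channel, so a component at graph frequency lambda is scaled by 1 - phi_c * lambda. The pure-Python sketch applies one such layer with fixed phi on a tiny graph; training phi, stacking layers, and all names used are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def normalized_laplacian(n: int, edges: Sequence[Tuple[int, int]]) -> Matrix:
    """L = I - D^(-1/2) (A + I) D^(-1/2), with D the degree matrix of A + I."""
    a = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]  # self-loops
    for u, v in edges:
        a[u][v] = a[v][u] = 1.0
    deg = [sum(row) for row in a]
    return [[(1.0 if i == j else 0.0) - a[i][j] / math.sqrt(deg[i] * deg[j])
             for j in range(n)] for i in range(n)]


def adaptive_filter_layer(lap: Matrix, x: Matrix, phi: Sequence[float]) -> Matrix:
    """X' = X - L X diag(phi): channel c is filtered with its own coefficient phi[c]."""
    n, channels = len(x), len(x[0])
    lx = [[sum(lap[i][k] * x[k][c] for k in range(n)) for c in range(channels)]
          for i in range(n)]
    return [[x[i][c] - phi[c] * lx[i][c] for c in range(channels)] for i in range(n)]


# Triangle graph; both channels carry the signal [1, 2, 3] across the three nodes.
lap = normalized_laplacian(3, [(0, 1), (1, 2), (0, 2)])
x = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
out = adaptive_filter_layer(lap, x, phi=[1.0, 0.0])
print([[round(v, 6) for v in row] for row in out])
# [[2.0, 1.0], [2.0, 2.0], [2.0, 3.0]]
```

With phi_c = 1 the layer reduces to the fixed low-pass propagation D^(-1/2) (A + I) D^(-1/2) X used by GCN, which smooths the triangle's channel to its mean; phi_c = 0 passes the channel through unchanged. Intermediate or stacked coefficients give in-between frequency responses, which is the flexibility a trainable, per-channel filter is meant to provide.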