

Title: Adaptive wavelet distillation from neural networks through interpretations
Recent deep-learning models have achieved impressive prediction performance, but often sacrifice interpretability and computational efficiency. Interpretability is crucial in many disciplines, such as science and medicine, where models must be carefully vetted or where interpretation is itself the goal. Moreover, interpretable models are concise and often computationally efficient. Here, we propose adaptive wavelet distillation (AWD), a method that aims to distill information from a trained neural network into a wavelet transform. Specifically, AWD penalizes feature attributions of a neural network in the wavelet domain to learn an effective multi-resolution wavelet transform. The resulting model is highly predictive, concise, and computationally efficient, and it has properties (such as a multi-scale structure) that make it easy to interpret. In close collaboration with domain experts, we showcase how AWD addresses challenges in two real-world settings: cosmological parameter inference and molecular-partner prediction. In both cases, AWD yields a scientifically interpretable and concise model that outperforms state-of-the-art neural networks. Moreover, AWD identifies predictive features that are scientifically meaningful in the context of each domain. All code and models are released in a full-fledged package available on GitHub.
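The abstract describes learning a multi-resolution wavelet transform whose filters are fitted during distillation. What follows is a minimal sketch of that kind of structure, not AWD's published implementation. It assumes a one-dimensional, periodically extended signal, a learnable lowpass filter `h`, and a highpass filter derived from `h` by the standard alternating-flip rule. All function names are hypothetical. The penalty encodes standard orthogonal-wavelet conditions as an illustrative assumption. The attribution penalty that the abstract describes requires a trained network and is omitted here.

```python
import numpy as np

def highpass_from_lowpass(h):
    """Alternating flip (quadrature mirror) rule: g[n] = (-1)**n * h[L-1-n]."""
    L = len(h)
    return np.array([(-1) ** n * h[L - 1 - n] for n in range(L)])

def dwt_level(x, h):
    """One analysis level of a periodic DWT: correlate with h and g, keep every 2nd sample."""
    x = np.asarray(x, dtype=float)
    g = highpass_from_lowpass(h)
    N = len(x)
    approx = np.zeros(N // 2)
    detail = np.zeros(N // 2)
    for k in range(N // 2):
        for n in range(len(h)):
            sample = x[(2 * k + n) % N]  # periodic extension at the boundary
            approx[k] += h[n] * sample
            detail[k] += g[n] * sample
    return approx, detail

def multilevel_dwt(x, h, levels):
    """Recursively split the approximation band, giving a multi-scale decomposition."""
    approx = np.asarray(x, dtype=float)
    details = []
    for _ in range(levels):
        approx, d = dwt_level(approx, h)
        details.append(d)  # details[0] is the finest scale
    return approx, details

def wavelet_penalty(h):
    """Squared violations of standard orthogonal-wavelet conditions on h (0 for a valid filter)."""
    h = np.asarray(h, dtype=float)
    L = len(h)
    loss = (np.sum(h) - np.sqrt(2)) ** 2  # lowpass: coefficients sum to sqrt(2)
    loss += (np.dot(h, h) - 1.0) ** 2     # unit energy
    for m in range(1, L // 2):            # orthogonal to its own even shifts
        loss += np.dot(h[: L - 2 * m], h[2 * m :]) ** 2
    return float(loss)
```

The Haar filter `h = [1/sqrt(2), 1/sqrt(2)]` and the 4-tap Daubechies filter both give a penalty of (numerically) zero. For such filters, the decomposition preserves the signal's energy across the approximation and detail bands. In a learned setting, `h` would presumably be a trainable parameter, and a penalty of this kind would be added to the training loss alongside the attribution term.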
Award ID(s):
2023505 2031883 1740855 1741340
NSF-PAR ID:
10343658
Author(s) / Creator(s):
; ; ; ;
Date Published:
Journal Name:
Advances in Neural Information Processing Systems
ISSN:
1049-5258
Format(s):
Medium: X
Sponsoring Org:
National Science Foundation
More Like this
  1. Electronic health records (EHRs) have been heavily used in modern healthcare systems to record patients' admission information at health facilities. Many data-driven approaches employ temporal features in EHRs to predict specific diseases, readmission times, and diagnoses of patients. However, most existing predictive models cannot fully utilize EHR data because some temporal events inherently lack labels for supervised training. Moreover, it is hard for existing methods to simultaneously provide generic and personalized interpretability. To address these challenges, we propose Sherbet, a self-supervised graph learning framework with hyperbolic embeddings for temporal health event prediction. We first propose a hyperbolic embedding method with information flow to pretrain medical code representations in a hierarchical structure. We incorporate these pretrained representations into a graph neural network (GNN) to detect disease complications and design a multilevel attention method to compute the contributions of particular diseases and admissions, thus enhancing personalized interpretability. We present a new hierarchy-enhanced historical prediction proxy task in our self-supervised learning framework to fully utilize EHR data and exploit medical domain knowledge. We conduct a comprehensive set of experiments on widely used, publicly available EHR datasets to verify the effectiveness of our model. Our results demonstrate the proposed model's strengths in both predictive performance and interpretability. 
  2. Machine learning methods, particularly neural networks trained on large datasets, are transforming how scientists approach scientific discovery and experimental design. However, current state-of-the-art neural networks are limited by their uninterpretability: Despite their excellent accuracy, they cannot describe how they arrived at their predictions. Here, using an “interpretable-by-design” approach, we present a neural network model that provides insights into RNA splicing, a fundamental process in the transfer of genomic information into functional biochemical products. Although we designed our model to emphasize interpretability, its predictive accuracy is on par with state-of-the-art models. To demonstrate the model’s interpretability, we introduce a visualization that, for any given exon, allows us to trace and quantify the entire decision process from input sequence to output splicing prediction. Importantly, the model revealed uncharacterized components of the splicing logic, which we experimentally validated. This study highlights how interpretable machine learning can advance scientific discovery. 
  3. Knowledge tracing (KT) refers to the problem of predicting a learner's future performance given their past performance in educational applications. Recent KT models based on flexible deep neural networks excel at this task. However, these models often offer limited interpretability, which makes them insufficient for personalized learning, since personalized learning requires interpretable feedback and actionable recommendations to help learners achieve better outcomes. In this paper, we propose attentive knowledge tracing (AKT), which couples flexible attention-based neural network models with a series of novel, interpretable model components inspired by cognitive and psychometric models. AKT uses a novel monotonic attention mechanism that relates a learner’s future responses to assessment questions to their past responses; attention weights are computed using exponential decay and a context-aware relative distance measure, in addition to the similarity between questions. Moreover, we use the Rasch model to regularize the concept and question embeddings; these embeddings capture individual differences among questions on the same concept without using an excessive number of parameters. We conduct experiments on several real-world benchmark datasets and show that AKT outperforms existing KT methods (by up to 6% in AUC in some cases) on predicting future learner responses. We also conduct several case studies and show that AKT exhibits excellent interpretability and thus has potential for automated feedback and personalization in real-world educational settings. 
  4. Modern data acquisition routinely produces massive amounts of event sequence data in various domains, such as social media, healthcare, and financial markets. These data often exhibit complicated short-term and long-term temporal dependencies. However, most existing recurrent neural network-based point process models fail to capture such dependencies and yield unreliable predictions. To address this issue, we propose a Transformer Hawkes Process (THP) model, which leverages the self-attention mechanism to capture long-term dependencies while remaining computationally efficient. Numerical experiments on various datasets show that THP outperforms existing models in terms of both likelihood and event prediction accuracy by a notable margin. Moreover, THP is quite general and can incorporate additional structural knowledge. We provide a concrete example in which THP achieves improved prediction performance for learning multiple point processes when incorporating their relational information. 
  5. Coronavirus disease 2019 (COVID-19) has had a profound impact on global health and the economy, making it crucial to build accurate and interpretable data-driven predictive models for COVID-19 cases to improve public policy making. The extremely large scale of the pandemic and its changing transmission characteristics make it very challenging to predict COVID-19 cases effectively. To address this challenge, we propose a novel hybrid model that combines the interpretability of the autoregressive (AR) model with the predictive power of long short-term memory (LSTM) neural networks. The proposed hybrid model is formalized as a neural network whose architecture connects two component blocks, with their relative contributions determined data-adaptively during training. Through comprehensive numerical studies on two data sources under multiple evaluation metrics, we show that the hybrid model outperforms its two individual component models as well as other popular predictive models. Specifically, on county-level data from 8 California counties, our hybrid model achieves a mean absolute percentage error (MAPE) of 4.173%, outperforming its AR (5.629%) and LSTM (4.934%) components alone on average. On country-level datasets, our hybrid model outperforms widely used predictive models such as AR, LSTM, support vector machines, gradient boosting, and random forests in predicting COVID-19 cases in Japan, Canada, Brazil, Argentina, Singapore, Italy, and the United Kingdom. Beyond predictive performance, we illustrate the interpretability of our proposed hybrid model using the estimated AR component, a key feature that most black-box predictive models for COVID-19 cases lack. 
Our study provides a new and promising direction for building effective and interpretable data-driven models for COVID-19 cases, which could have significant implications for public health policy making and control of the current COVID-19 and potential future pandemics. 
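The hybrid COVID-19 model in the last entry blends an interpretable AR block with an LSTM block and reports its accuracy as MAPE. What follows is a minimal sketch of that arithmetic, not the authors' code. It assumes a one-step AR(p) forecast and uses a sigmoid gate, whose logit would be learned during training, as one hypothetical way to set the blocks' relative contribution. The LSTM forecast is supplied as a plain number, since no LSTM is implemented here. All function names are illustrative.

```python
import numpy as np

def ar_forecast(history, coeffs, intercept=0.0):
    """One-step AR(p) forecast: intercept + sum_i coeffs[i] * history[-1 - i]."""
    p = len(coeffs)
    recent = np.asarray(history[-p:], dtype=float)[::-1]  # most recent value first
    return intercept + float(np.dot(coeffs, recent))

def blend(ar_pred, lstm_pred, gate_logit):
    """Convex combination; sigmoid(gate_logit) is the AR block's weight (hypothetical gate)."""
    w = 1.0 / (1.0 + np.exp(-gate_logit))
    return w * ar_pred + (1.0 - w) * lstm_pred

def mape(y_true, y_pred):
    """Mean absolute percentage error, in percent. Assumes no zero targets."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return 100.0 * float(np.mean(np.abs((y_true - y_pred) / y_true)))
```

For example, `ar_forecast([1.0, 2.0, 3.0], [0.5, 0.25])` weights the latest value 3.0 by 0.5 and the previous value 2.0 by 0.25, giving 2.0. `mape([100, 200], [110, 180])` gives 10.0 because both forecasts are off by 10%. With a single scalar gate, the learned AR weight can be read off directly, which fits the abstract's emphasis on the AR component as the interpretable part of the model.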