Title: Connecting Interpretability and Robustness in Decision Trees through Separation
Recent research has recognized interpretability and robustness as essential properties of trustworthy classification. Curiously, a connection between robustness and interpretability was empirically observed, but the theoretical reasoning behind it remained elusive. In this paper, we rigorously investigate this connection. Specifically, we focus on interpretation using decision trees and robustness to l1-perturbation. Previous works defined the notion of r-separation as a sufficient condition for robustness. We prove upper and lower bounds on the tree size when the data is r-separated. We then show that a tighter bound on the size is possible when the data is linearly separable. We provide the first algorithm for decision trees with provable guarantees on robustness, interpretability, and accuracy simultaneously. Experiments confirm that our algorithm yields classifiers that are interpretable, robust, and highly accurate.
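
To make the r-separation condition concrete, here is a minimal sketch that computes the largest r for which a labeled dataset is r-separated. It assumes the definition used in prior work: every pair of points with different labels is at distance at least 2r. The norm is left as a parameter (`ord`, passed to `numpy.linalg.norm`) rather than fixed, and the function names are hypothetical; this is not the paper's algorithm.

```python
import itertools

import numpy as np


def separation_radius(X, y, ord):
    """Largest r such that (X, y) is r-separated under the given norm.

    Assumed definition: every pair of points with different labels is at
    distance >= 2r. Returns inf if no differently-labeled pair exists.
    """
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    closest = np.inf
    # Brute force over all pairs: O(n^2), fine for a small illustration.
    for i, j in itertools.combinations(range(len(X)), 2):
        if y[i] != y[j]:
            closest = min(closest, np.linalg.norm(X[i] - X[j], ord=ord))
    return closest / 2.0


def is_r_separated(X, y, r, ord):
    """True if the data satisfies the assumed r-separation condition."""
    return separation_radius(X, y, ord) >= r
```

The same dataset can have different separation radii under different norms, which is why the norm is an explicit argument.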
Award ID(s): 1804829
NSF-PAR ID: 10282045
Author(s) / Creator(s): ; ;
Date Published:
Journal Name: Proceedings of Machine Learning Research
Issue: 139
ISSN: 2640-3498
Format(s): Medium: X
Sponsoring Org: National Science Foundation
More like this
  1. Decision forests are popular machine learning techniques that help scientists extract knowledge from massive data sets. This class of tools remains popular because of its interpretability and ease of use, unlike other modern machine learning methods such as kernel machines and deep learning. Decision forests also scale well to large data because training and run-time operations are trivially parallelizable, allowing for high inference throughput. A drawback of these forests, and an untenable property for many real-time applications, is their high inference latency, caused by the combination of large model sizes with random memory access patterns. We present memory packing techniques and a novel tree traversal method to overcome this deficiency. Our system groups trees into a hierarchical structure. At low levels, we pack the nodes of multiple trees into contiguous memory blocks so that each memory access fetches data for multiple trees. At higher levels, we use leaf cardinality to identify the most popular paths through a tree and collocate those paths in contiguous cache lines. We extend this layout with a re-ordering of the tree traversal algorithm to take advantage of the increased memory throughput provided by out-of-order execution and cache-line prefetching. Together, these optimizations increase the performance and parallel scalability of classification in ensembles by a factor of ten over an optimized C++ implementation and a popular R-language implementation.
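
The layout and traversal ideas above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the authors' system: the `Node` class and function names are hypothetical, all trees are flattened into shared arrays with the child of higher leaf cardinality placed immediately after its parent (so the most popular path is contiguous), and prediction advances every tree by one level per round, mimicking an interleaved traversal whose independent loads a compiled implementation could overlap. Python itself will not show the cache effects.

```python
from dataclasses import dataclass
from typing import List, Optional

import numpy as np


@dataclass
class Node:
    """Hypothetical tree node; `count` is the training cardinality of the node."""
    count: int
    feature: int = -1          # -1 marks a leaf
    threshold: float = 0.0
    value: float = 0.0         # prediction stored at a leaf
    left: Optional["Node"] = None
    right: Optional["Node"] = None


def pack_forest(trees: List[Node]):
    """Flatten all trees into shared arrays, emitting the more popular child
    right after its parent so hot paths occupy consecutive slots."""
    feature, threshold, value, left, right = [], [], [], [], []

    def emit(node: Node) -> int:
        idx = len(feature)
        feature.append(node.feature)
        threshold.append(node.threshold)
        value.append(node.value)
        left.append(-1)
        right.append(-1)
        if node.feature >= 0:
            if node.left.count >= node.right.count:
                left[idx] = emit(node.left)
                right[idx] = emit(node.right)
            else:
                right[idx] = emit(node.right)
                left[idx] = emit(node.left)
        return idx

    roots = [emit(tree) for tree in trees]
    arrays = [np.asarray(a) for a in (feature, threshold, value, left, right)]
    return (*arrays, roots)


def predict_interleaved(packed, x):
    """Round-robin traversal: one step per tree per round, then average leaves."""
    feature, threshold, value, left, right, roots = packed
    cursors = list(roots)
    active = list(range(len(cursors)))
    while active:
        still_active = []
        for t in active:
            n = cursors[t]
            f = feature[n]
            if f < 0:
                continue  # this tree has reached a leaf
            cursors[t] = left[n] if x[f] <= threshold[n] else right[n]
            still_active.append(t)
        active = still_active
    return float(np.mean(value[cursors]))
```

Sorting children by cardinality is one simple way to realize "collocate popular paths"; a real implementation would also bin nodes into cache-line-sized blocks.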
  2. Tree-based models such as decision trees and random forests (RF) are a cornerstone of modern machine-learning practice. To mitigate overfitting, trees are typically regularized by a variety of techniques that modify their structure (e.g., pruning). We introduce Hierarchical Shrinkage (HS), a post-hoc algorithm that does not modify the tree structure and instead regularizes the tree by shrinking the prediction over each node towards the sample means of its ancestors. The amount of shrinkage is controlled by a single regularization parameter and the number of data points in each ancestor. Since HS is a post-hoc method, it is extremely fast, compatible with any tree-growing algorithm, and can be used synergistically with other regularization techniques. Extensive experiments over a wide variety of real-world datasets show that HS substantially increases the predictive performance of decision trees, even when used in conjunction with other regularization techniques. Moreover, we find that applying HS to each tree in an RF often improves accuracy, as well as its interpretability, by simplifying and stabilizing its decision boundaries and SHAP values. We further explain the success of HS in improving prediction performance by showing its equivalence to ridge regression on a (supervised) basis constructed of decision stumps associated with the internal nodes of a tree. All code and models are released in a full-fledged package available on GitHub.
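
A minimal sketch of the shrinkage rule for one root-to-leaf path t_0, ..., t_L follows. It assumes the form in which each increment in node mean is damped by its parent's sample count: f(x) = mean(t_0) + sum over l of (mean(t_l) - mean(t_{l-1})) / (1 + lam / N(t_{l-1})). With lam = 0 the sum telescopes back to the unregularized leaf mean; as lam grows the prediction moves toward the root mean. The function name is hypothetical and this is not the released package's API.

```python
def hierarchical_shrinkage(path_means, path_counts, lam):
    """Shrunk prediction for one root-to-leaf path.

    path_means[l]  : mean response of training samples in node t_l (t_0 = root)
    path_counts[l] : number of training samples in node t_l
    lam            : regularization parameter (lam = 0 recovers the leaf mean)
    """
    if not path_means or len(path_means) != len(path_counts):
        raise ValueError("path_means and path_counts must be non-empty and aligned")
    prediction = path_means[0]
    for l in range(1, len(path_means)):
        step = path_means[l] - path_means[l - 1]
        # Parents with few samples shrink the step more strongly.
        prediction += step / (1.0 + lam / path_counts[l - 1])
    return prediction
```

Because only node means and counts are needed, this can be applied after any tree-growing algorithm without touching the tree structure, which matches the post-hoc property described above.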
  3. Decision trees have been a very popular class of predictive models for decades due to their interpretability and good performance on categorical features. However, they are not always robust and tend to overfit the data. Additionally, if allowed to grow large, they lose interpretability. In this paper, we present a mixed integer programming formulation to construct optimal decision trees of a prespecified size. We take the special structure of categorical features into account and allow combinatorial decisions (based on subsets of values of features) at each node. Our approach can also handle numerical features via thresholding. We show that very good accuracy can be achieved with small trees using moderately sized training sets. The optimization problems we solve are tractable with modern solvers.
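
To illustrate what a combinatorial (subset-based) decision at a node means, the sketch below searches every subset split of a single categorical feature by brute force and scores each by misclassification count. This is a swapped-in exhaustive search for illustration only, not the paper's mixed integer programming formulation; the function name is hypothetical. The 2^k growth in candidates is what makes an optimization formulation attractive for larger problems.

```python
from collections import Counter
from itertools import combinations


def best_subset_split(values, labels):
    """Exhaustively search the subset splits of one categorical feature.

    Samples whose category is in the returned subset go left, the rest go
    right, and each side predicts its majority label.
    Returns (subset, misclassified_count).
    """
    categories = sorted(set(values))
    if len(categories) < 2:
        raise ValueError("need at least two categories to split on")
    best_subset, best_errors = None, len(labels) + 1
    # All non-empty proper subsets: 2^k - 2 candidates for k categories.
    for size in range(1, len(categories)):
        for subset in combinations(categories, size):
            chosen = set(subset)
            left = Counter(l for v, l in zip(values, labels) if v in chosen)
            right = Counter(l for v, l in zip(values, labels) if v not in chosen)
            errors = (sum(left.values()) - max(left.values())
                      + sum(right.values()) - max(right.values()))
            if errors < best_errors:
                best_subset, best_errors = frozenset(chosen), errors
    return best_subset, best_errors
```

A single threshold on an arbitrary ordering of categories could not always isolate a set like {"green"} from {"blue", "red"}; the subset test can.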
  4. This work presents SeizFt, a novel seizure detection framework that uses machine learning to automatically detect seizures in wearable SensorDot EEG data. Inspired by interpretable sleep staging, our approach employs a unique combination of data augmentation, meaningful feature extraction, and an ensemble of decision trees to improve resilience to variations in EEG and to increase the capacity to generalize to unseen data. Fourier transform (FT) surrogates were used to increase sample size and improve the class balance between labeled non-seizure and seizure epochs. To enhance model stability and accuracy, SeizFt uses an ensemble of decision trees through the CatBoost classifier to classify each second of EEG recording as seizure or non-seizure. The SeizIt1 dataset was used for training, and the SeizIt2 dataset for validation and testing. Model performance for seizure detection was evaluated using two primary metrics: sensitivity using the any-overlap method (OVLP) and False Alarm (FA) rate using epoch-based scoring (EPOCH). Notably, SeizFt placed first among an array of state-of-the-art seizure detection algorithms in the Seizure Detection Grand Challenge at the 2023 International Conference on Acoustics, Speech, and Signal Processing (ICASSP). SeizFt outperformed state-of-the-art black-box models in seizure detection accuracy while minimizing false alarms, obtaining a total score of 40.15 (combining OVLP and EPOCH across two tasks), an improvement of ~30% over the next-best approach. The interpretability of SeizFt is a key advantage, as it fosters trust and accountability among healthcare professionals. The most predictive seizure detection features extracted by SeizFt were: delta wave, interquartile range, standard deviation, total absolute power, theta wave, the ratio of delta to theta, binned entropy, Hjorth complexity, delta + theta, and Higuchi fractal dimension.
In conclusion, the successful application of SeizFt to wearable SensorDot data suggests its potential for real-time, continuous monitoring to improve personalized medicine for epilepsy.
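
The augmentation and feature steps above can be sketched with NumPy. This is a minimal illustration, not SeizFt's code: `ft_surrogate` uses the standard phase-randomization construction (keep the amplitude spectrum, draw random phases, invert), and `epoch_features` computes a handful of the listed features. The band edges (delta 0.5 to 4 Hz, theta 4 to 8 Hz), the unnormalized periodogram power, and all function names are assumptions.

```python
import numpy as np


def ft_surrogate(x, rng):
    """Phase-randomized Fourier-transform surrogate of a 1-D real signal.

    The amplitude spectrum (hence the power spectrum) is preserved exactly;
    only the phases are randomized.
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    spectrum = np.fft.rfft(x)
    phases = rng.uniform(0.0, 2.0 * np.pi, size=spectrum.size)
    phases[0] = 0.0                # keep the DC (mean) bin real
    if n % 2 == 0:
        phases[-1] = 0.0           # keep the Nyquist bin real for even lengths
    return np.fft.irfft(spectrum * np.exp(1j * phases), n=n)


def band_power(x, fs, lo, hi):
    """Unnormalized periodogram power in the band [lo, hi) Hz."""
    freqs = np.fft.rfftfreq(x.size, d=1.0 / fs)
    power = np.abs(np.fft.rfft(x)) ** 2
    return float(power[(freqs >= lo) & (freqs < hi)].sum())


def hjorth_complexity(x):
    """Hjorth complexity: mobility of the first difference over mobility of x."""
    dx, ddx = np.diff(x), np.diff(x, n=2)
    mobility = np.sqrt(np.var(dx) / np.var(x))
    mobility_dx = np.sqrt(np.var(ddx) / np.var(dx))
    return float(mobility_dx / mobility)


def epoch_features(x, fs):
    """A subset of the listed features for one EEG epoch (assumed band edges)."""
    x = np.asarray(x, dtype=float)
    delta = band_power(x, fs, 0.5, 4.0)
    theta = band_power(x, fs, 4.0, 8.0)
    q75, q25 = np.percentile(x, [75, 25])
    return {
        "delta": delta,
        "theta": theta,
        "delta_theta_ratio": delta / (theta + 1e-12),  # epsilon avoids division by zero
        "delta_plus_theta": delta + theta,
        "std": float(np.std(x)),
        "iqr": float(q75 - q25),
        "hjorth_complexity": hjorth_complexity(x),
    }
```

Surrogates of minority-class (seizure) epochs keep their spectral content, so band-power features stay plausible while the waveform changes, which is why this kind of augmentation suits class balancing.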

  5. Background

    Advanced machine learning models have received wide attention for assisting medical decision making because of the greater accuracy they can achieve. However, their limited interpretability creates barriers to adoption by practitioners. Recent advances in interpretable machine learning tools allow us to look inside the black box of advanced prediction methods and extract interpretable models while maintaining similar prediction accuracy, but few studies have investigated the specific problem of hospital readmission prediction in this spirit.

    Methods

    Our goal is to develop a machine learning (ML) algorithm that can predict 30- and 90-day hospital readmissions as accurately as black-box algorithms while providing medically interpretable insights into readmission risk factors. Leveraging a state-of-the-art interpretable ML model, we use a two-step Extracted Regression Tree approach to achieve this goal. In the first step, we train a black-box prediction algorithm. In the second step, we extract a regression tree from the output of the black-box algorithm that allows direct interpretation of medically relevant risk factors. We use data from a large teaching hospital in Asia to train the ML model and verify our two-step approach.

    Results

    The two-step method achieves prediction performance similar to that of the best black-box model (e.g., neural networks), as measured by three metrics: accuracy, the Area Under the Curve (AUC), and the Area Under the Precision-Recall Curve (AUPRC), while maintaining interpretability. Further, to examine whether the prediction results match known medical insights (i.e., whether the model is truly interpretable and produces reasonable results), we show that the key readmission risk factors extracted by the two-step approach are consistent with those found in the medical literature.

    Conclusions

    The proposed two-step approach yields prediction results that are both accurate and interpretable. This study suggests that the two-step approach is a viable means of improving trust in machine-learning-based models for predicting readmissions in clinical practice.
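
The two-step idea (train a black box, then fit a regression tree to the black box's outputs rather than to the raw labels) can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the study's Extracted Regression Tree implementation: a k-nearest-neighbour probability estimate stands in for the black box (the study compared models such as neural networks), the tree is a generic greedy squared-error splitter, and all function names are hypothetical.

```python
import numpy as np


def black_box_scores(X_train, y_train, X_query, k=15):
    """Step 1 stand-in black box: k-nearest-neighbour readmission probability."""
    dists = np.linalg.norm(X_query[:, None, :] - X_train[None, :, :], axis=2)
    nearest = np.argsort(dists, axis=1)[:, :k]
    return y_train[nearest].mean(axis=1)


def fit_tree(X, target, depth, min_leaf=5):
    """Step 2: greedy squared-error regression tree fit to the black-box scores."""
    if depth == 0 or len(target) < 2 * min_leaf:
        return {"value": float(target.mean())}
    best = None
    for f in range(X.shape[1]):
        for thr in np.unique(X[:, f])[:-1]:      # exclude the max so both sides are non-empty
            mask = X[:, f] <= thr
            if mask.sum() < min_leaf or (~mask).sum() < min_leaf:
                continue
            sse = (((target[mask] - target[mask].mean()) ** 2).sum()
                   + ((target[~mask] - target[~mask].mean()) ** 2).sum())
            if best is None or sse < best[0]:
                best = (sse, f, thr, mask)
    if best is None:
        return {"value": float(target.mean())}
    _, f, thr, mask = best
    return {"feature": int(f), "threshold": float(thr),
            "left": fit_tree(X[mask], target[mask], depth - 1, min_leaf),
            "right": fit_tree(X[~mask], target[~mask], depth - 1, min_leaf)}


def predict_tree(tree, x):
    """Route one sample to a leaf and return the leaf's mean black-box score."""
    while "value" not in tree:
        tree = tree["left"] if x[tree["feature"]] <= tree["threshold"] else tree["right"]
    return tree["value"]
```

Usage, under the same assumptions: `scores = black_box_scores(X, y, X)` followed by `tree = fit_tree(X, scores, depth=2)`. The split features and thresholds of the shallow tree are what a clinician would inspect as candidate risk factors.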
