FIGS: Fast Interpretable Greedy-Tree Sums



📄 Paper, 🗂 Doc, 📌 Citation

Yan Shuo Tan*, Chandan Singh*, Keyan Nasseri*, Abhineet Agarwal*, James Duncan, Omer Ronen, Matthew Epland, Aaron Kornblith, Bin Yu

Modern machine learning has achieved impressive prediction performance, but often sacrifices interpretability, a critical consideration in many problems. Here, we propose Fast Interpretable Greedy-Tree Sums (FIGS), an algorithm for fitting concise rule-based models. Specifically, FIGS generalizes the CART algorithm to work on sums of trees, growing a flexible number of them simultaneously. The total number of splits across all the trees is restricted by a pre-specified threshold, which ensures that FIGS remains interpretable. Extensive experiments show that FIGS achieves state-of-the-art performance across a wide array of real-world datasets when restricted to very few splits (e.g., fewer than 20). Theoretical and simulation results suggest that FIGS overcomes a key weakness of single-tree models by disentangling the components of additive generative models, thereby significantly improving convergence rates for ℓ2 generalization error. We further characterize the success of FIGS by quantifying how it reduces repeated splits, which can lead to redundancy in single-tree models such as CART. All code and models are released in a full-fledged package available on GitHub.

How does FIGS work?

Intuitively, FIGS works by extending CART, a typical greedy algorithm for growing a decision tree, to consider growing a sum of trees simultaneously (see Fig 1). At each iteration, FIGS may grow any existing tree it has already started or start a new tree; it greedily selects whichever rule reduces the total unexplained variance (or an alternative splitting criterion) the most. To keep the trees in sync with one another, each tree is made to predict the residuals remaining after summing the predictions of all other trees.

FIGS is intuitively similar to ensemble approaches such as gradient boosting and random forests, but with an important difference: because all trees are grown to compete with one another, the model can adapt more closely to the underlying structure in the data. The number of trees and the size/shape of each tree emerge automatically from the data rather than being specified manually.


Fig 1. High-level intuition for how FIGS fits a model.
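
To make this loop concrete, here is a heavily simplified, self-contained sketch in Python. It is not the imodels implementation: trees are stored as flat lists of (training-sample mask, value) leaves, the split search is a brute-force CART-style scan, and details such as how leaf values are refreshed after each split are simplified away. The function name fit_figs_sketch and the toy demo at the end are purely illustrative.

    import numpy as np

    def fit_figs_sketch(X, y, max_rules=5):
        """Toy sketch of the FIGS fitting loop (NOT the imodels implementation).
        Each tree is a flat list of leaves; a leaf is (mask over training samples, value),
        so this sketch can only predict on the data it was fit on."""
        trees = []
        while sum(len(t) - 1 for t in trees) < max_rules:        # shared budget of splits
            fresh = [(np.ones(len(y), bool), 0.0)]               # candidate brand-new tree (one root leaf)
            best = None                                          # (gain, tree, leaf idx, feature, threshold, residual)
            for tree in trees + [fresh]:                         # grow an existing tree or start a new one
                # residual left over after summing the predictions of all *other* trees
                resid = y - sum((m * v for t in trees if t is not tree for m, v in t),
                                np.zeros(len(y)))
                for li, (mask, _) in enumerate(tree):            # consider splitting every leaf
                    r = resid[mask]
                    if r.size < 2:
                        continue
                    parent_sse = ((r - r.mean()) ** 2).sum()
                    for j in range(X.shape[1]):                  # CART-style exhaustive split search
                        xs = X[mask, j]
                        for thr in np.unique(xs)[:-1]:
                            lo, hi = r[xs <= thr], r[xs > thr]
                            gain = (parent_sse - ((lo - lo.mean()) ** 2).sum()
                                               - ((hi - hi.mean()) ** 2).sum())
                            if best is None or gain > best[0]:
                                best = (gain, tree, li, j, thr, resid)
            if best is None or best[0] <= 0:                     # stop when no split helps
                break
            gain, tree, li, j, thr, resid = best
            if tree is fresh:                                    # the winning split starts a new tree
                trees.append(tree)
            mask = tree[li][0]
            left, right = mask & (X[:, j] <= thr), mask & (X[:, j] > thr)
            tree[li] = (left, resid[left].mean())                # replace the split leaf ...
            tree.append((right, resid[right].mean()))            # ... with its two children
        return trees

    # tiny demo: an additive target typically ends up spread across several shallow trees
    X = np.random.rand(200, 4)
    y = (X[:, 0] > 0.5) + (X[:, 1] > 0.5) + 0.1 * np.random.randn(200)
    print([len(t) - 1 for t in fit_figs_sketch(X, y, max_rules=6)])  # splits per tree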

An example using FIGS

FIGS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the fit and predict methods. Here's a full example of using it on a sample clinical dataset.
        
    from imodels import FIGSClassifier, get_clean_dataset
    from sklearn.model_selection import train_test_split

    # prepare data (in this case, a sample clinical dataset)
    X, y, feat_names = get_clean_dataset('csi_pecarn_pred')
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42)

    # fit the model
    model = FIGSClassifier(max_rules=4)  # initialize a model
    model.fit(X_train, y_train)          # fit model
    preds = model.predict(X_test)        # discrete predictions: shape is (n_test, 1)
    preds_proba = model.predict_proba(X_test)  # predicted probabilities: shape is (n_test, n_classes)

    # visualize the model
    model.plot(feature_names=feat_names, filename='out.svg', dpi=300)
            
    
This results in a simple model -- it contains only 4 splits, since we specified that the model should have no more than 4 splits (max_rules=4). Predictions are made by summing the values obtained from the appropriate leaf of each tree. This model is highly interpretable: a physician can (i) easily make predictions using the 4 relevant features and (ii) vet the model to ensure it matches their domain expertise. Note that this model is just for illustration purposes; it achieves ~84% accuracy.


Fig 2. Simple model learned by FIGS for predicting risk of cervical spinal injury.
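
To make the summation step concrete, the tiny example below walks through how such a risk score would be assembled. The feature names and leaf values are hypothetical, chosen only for exposition; they are not the values learned in Fig 2.

    # Hypothetical illustration only -- feature names and leaf values are made up,
    # not taken from the model in Fig 2. Each tree contributes the value of the leaf
    # that the patient falls into, and the contributions are summed into one risk score.
    patient = {'focal_neuro_findings': True, 'torticollis': False}   # illustrative features
    risk = 0.0
    risk += 0.35 if patient['focal_neuro_findings'] else 0.02        # contribution of tree 1
    risk += 0.28 if patient['torticollis'] else 0.05                 # contribution of tree 2
    print(f'predicted risk score: {risk:.2f}')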

If we want a more flexible model, we can also remove the constraint on the number of rules (changing the code to model = FIGSClassifier()), resulting in a larger model (see Fig 3). Note that the number of trees and how balanced they are emerge from the structure of the data -- only the total number of rules may be specified.


Fig 3. Slightly larger model learned by FIGS for predicting risk of cervical spinal injury.
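
Concretely, the only change from the earlier snippet is the constructor call; the variable name and output filename below are just illustrative choices.

    model_full = FIGSClassifier()   # no max_rules cap -> model size emerges from the data
    model_full.fit(X_train, y_train)
    model_full.plot(feature_names=feat_names, filename='out_full.svg', dpi=300)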

Another example of using FIGS

Here, we examine the Diabetes classification dataset, in which eight risk factors were collected and used to predict the onset of diabetes within five years. Fitting several models, we find that with very few rules, FIGS can achieve excellent test performance.

For example, Fig 4 shows a model fitted using the FIGS algorithm which achieves a test AUC of 0.820 despite being extremely simple. In this model, each feature contributes independently of the others, and the risk contributions from three key features are summed to obtain the overall risk of diabetes onset (a higher score indicates higher risk). As opposed to a black-box model, this model is easy to interpret, fast to compute with, and allows us to vet the features being used for decision-making.


Fig 4. Simple model learned by FIGS for diabetes risk prediction.
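
A sketch of how such a model might be fit and evaluated is shown below. The dataset identifier passed to get_clean_dataset and the choice of max_rules=3 are assumptions made for illustration, not the exact setup behind the 0.820 figure.

    from imodels import FIGSClassifier, get_clean_dataset
    from sklearn.metrics import roc_auc_score
    from sklearn.model_selection import train_test_split

    # NOTE: the dataset name here is an assumption -- substitute however you load
    # the eight-feature diabetes data as a feature matrix X and binary labels y
    X, y, feat_names = get_clean_dataset('diabetes')
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42)

    model = FIGSClassifier(max_rules=3)          # keep the model very small
    model.fit(X_train, y_train)
    probs = model.predict_proba(X_test)[:, 1]    # predicted probability of diabetes onset
    print(f'test AUC: {roc_auc_score(y_test, probs):.3f}')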