Hierarchical shrinkage: improving the accuracy and interpretability of tree-based methods
Abhineet Agarwal*, Yan Shuo Tan*, Omer Ronen, Chandan Singh, Bin Yu
📄 Paper (ICML 2022), 🗂 Doc, 📌 Citation
Hierarchical shrinkage (HS) is an extremely fast post-hoc regularization method that works on any decision tree (or tree-based ensemble, such as a Random Forest). It does not modify the tree structure; instead, it regularizes the tree by shrinking the prediction at each node towards the sample means of its ancestors, controlled by a single regularization parameter. Experiments over a wide variety of datasets show that hierarchical shrinkage substantially increases the predictive performance of individual decision trees and decision-tree ensembles.
How does hierarchical shrinkage work?
Fig 1. HS applies post-hoc regularization to any decision tree by shrinking each node towards its parent. This is done after the tree has been trained. The amount of shrinkage is controlled by a regularization parameter, which works best when chosen via cross-validation.
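In formulas, following our reading of the paper: for a query point $x$, let $t_0, t_1, \dots, t_L$ be the nodes on the path from the root $t_0$ down to the leaf $t_L$ that contains $x$. Write $\mathbb{E}_{t}[y]$ for the mean response of the training samples in node $t$, $N(t)$ for the number of those samples, and $\lambda \ge 0$ for the regularization parameter. The unregularized tree predicts $\mathbb{E}_{t_L}[y]$, which is a telescoping sum of parent-to-child increments; HS damps each increment according to the size of the parent node:

$$
\hat{f}_\lambda(x) = \mathbb{E}_{t_0}[y] + \sum_{l=1}^{L} \frac{\mathbb{E}_{t_l}[y] - \mathbb{E}_{t_{l-1}}[y]}{1 + \lambda / N(t_{l-1})}
$$

With $\lambda = 0$ the sum telescopes back to the original prediction $\mathbb{E}_{t_L}[y]$; as $\lambda \to \infty$ every prediction collapses to the root mean $\mathbb{E}_{t_0}[y]$. Increments leaving small nodes, whose means are the noisiest, are damped the most.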
An example using HS
HS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the `fit` and `predict` methods.
Here's a full example of using it on a sample clinical dataset.
```python
from imodels import HSTreeClassifierCV, get_clean_dataset
from sklearn.model_selection import train_test_split
from sklearn.tree import plot_tree

# prepare data (here, a sample clinical dataset)
X, y, feat_names = get_clean_dataset('csi_pecarn_pred')
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

# fit the model
model = HSTreeClassifierCV(max_leaf_nodes=7)  # initialize a model
model.fit(X_train, y_train)  # fit model
preds = model.predict(X_test)  # discrete predictions: shape is (n_test,)
preds_proba = model.predict_proba(X_test)  # predicted probabilities: shape is (n_test, n_classes)

# visualize the model
plot_tree(model.estimator_, feature_names=feat_names)
```
Here we used `HSTreeClassifierCV`, which selects the amount of regularization via cross-validation, but we can also use `HSTreeClassifier` if we want to specify a particular amount of regularization.
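`HSTreeClassifierCV` performs this selection for you. Purely to illustrate the idea, here is a minimal hand-rolled sketch for the regression case that uses only NumPy and scikit-learn, not imodels: on each cross-validation fold it fits a CART tree, applies hierarchical shrinkage directly to the fitted node means, and keeps the regularization value with the lowest average validation error. The helpers `hs_node_values` and `select_reg_param`, the candidate grid, and the synthetic data are hypothetical illustrations, not the imodels implementation.

```python
import numpy as np
from sklearn.model_selection import KFold
from sklearn.tree import DecisionTreeRegressor


def hs_node_values(tree, reg_param):
    """Hypothetical helper (not imodels): HS-shrunk value of every node of a fitted regression tree."""
    t = tree.tree_
    raw = t.value[:, 0, 0]             # mean response of the training samples in each node
    n = t.weighted_n_node_samples      # node sizes (equal to plain counts without sample weights)
    shrunk = np.empty_like(raw)
    shrunk[0] = raw[0]                 # node 0 is the root; it is never shrunk
    stack = [0]
    while stack:                       # depth-first walk: a parent is always set before its children
        parent = stack.pop()
        for child in (t.children_left[parent], t.children_right[parent]):
            if child != -1:            # -1 means "no child", i.e. parent is a leaf
                step = (raw[child] - raw[parent]) / (1.0 + reg_param / n[parent])
                shrunk[child] = shrunk[parent] + step
                stack.append(child)
    return shrunk


def select_reg_param(X, y, reg_params, max_leaf_nodes=8, n_splits=3, seed=0):
    """Hypothetical helper: return the reg_param with the lowest summed validation MSE."""
    errors = np.zeros(len(reg_params))
    for train_idx, val_idx in KFold(n_splits, shuffle=True, random_state=seed).split(X):
        tree = DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes, random_state=seed)
        tree.fit(X[train_idx], y[train_idx])
        leaves = tree.apply(X[val_idx])          # leaf index of each validation point
        for i, lam in enumerate(reg_params):
            preds = hs_node_values(tree, lam)[leaves]
            errors[i] += np.mean((preds - y[val_idx]) ** 2)
    return reg_params[int(np.argmin(errors))]


# synthetic regression data (illustrative only)
rng = np.random.default_rng(0)
X = rng.uniform(size=(300, 3))
y = 2.0 * X[:, 0] + rng.normal(scale=0.5, size=300)

best = select_reg_param(X, y, reg_params=[0.0, 1.0, 10.0, 50.0, 100.0])
print("selected reg_param:", best)
```

Because the shrinkage only rewrites node values, each fold needs a single tree fit, and every candidate value is evaluated by cheap re-weighting of that tree.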
For regression, we can use the corresponding classes: `HSTreeRegressorCV` and `HSTreeRegressor`.
Fig 2. Simple model learned by HS for predicting risk of cervical spinal injury.
Examples with HS on synthetic data
Below are examples of how hierarchical shrinkage behaves on one-dimensional functions fitted with a CART decision tree.
Fig 3. Step function.
Fig 4. Linear function.
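The same kind of experiment can be sketched numerically. Assuming a noisy step function and a noisy linear function on [0, 1] (our own hypothetical setup, not the one behind the figures), the code below fits a fully grown CART tree with scikit-learn, shrinks its node values for several regularization values using a hand-rolled helper `hs_node_values` (hypothetical, not imodels), and prints the error against the noise-free function. It demonstrates the mechanics only; we make no claim about which setting wins on this toy data.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor


def hs_node_values(tree, reg_param):
    """Hypothetical helper (not imodels): HS-shrunk value of every node of a fitted regression tree."""
    t = tree.tree_
    raw = t.value[:, 0, 0]             # mean response of the training samples in each node
    n = t.weighted_n_node_samples      # node sizes (equal to plain counts without sample weights)
    shrunk = np.empty_like(raw)
    shrunk[0] = raw[0]                 # the root is never shrunk
    stack = [0]
    while stack:                       # parents are always processed before their children
        parent = stack.pop()
        for child in (t.children_left[parent], t.children_right[parent]):
            if child != -1:            # -1 marks a missing child (parent is a leaf)
                step = (raw[child] - raw[parent]) / (1.0 + reg_param / n[parent])
                shrunk[child] = shrunk[parent] + step
                stack.append(child)
    return shrunk


def true_fn(x, kind):
    """Hypothetical noise-free targets: a step at 0.5, or the identity line."""
    return (x > 0.5).astype(float) if kind == "step" else x


rng = np.random.default_rng(0)
x_train = rng.uniform(size=(200, 1))
x_grid = np.linspace(0, 1, 500).reshape(-1, 1)   # dense grid for evaluating the fitted curve

for kind in ("step", "linear"):
    y_train = true_fn(x_train[:, 0], kind) + rng.normal(scale=0.3, size=200)
    tree = DecisionTreeRegressor(random_state=0).fit(x_train, y_train)  # fully grown CART tree
    leaves = tree.apply(x_grid)
    for lam in (0.0, 1.0, 10.0, 100.0):
        preds = hs_node_values(tree, lam)[leaves]
        mse = np.mean((preds - true_fn(x_grid[:, 0], kind)) ** 2)
        print(f"{kind:>6}  reg_param={lam:>5}: MSE vs. true function = {mse:.4f}")
```

Setting `reg_param=0.0` reproduces the unregularized tree exactly, so the first line printed for each function is the plain CART baseline.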
Applying HS to tree ensembles
HS can also regularize each tree in a tree ensemble (e.g. a Random Forest): simply pass the desired ensemble as the estimator during initialization.
```python
from sklearn.ensemble import RandomForestClassifier  # also works with ExtraTreesClassifier, GradientBoostingClassifier
from imodels import HSTreeClassifier

# X_train and y_train come from the clinical-data example above
ensemble = RandomForestClassifier()
model = HSTreeClassifier(estimator_=ensemble)  # HS regularizes each tree in the ensemble
model = model.fit(X_train, y_train)
```
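To show what per-tree shrinkage means at the level of the fitted trees, here is a hedged sketch for regression using only NumPy and scikit-learn rather than imodels: it shrinks every tree of a fitted `RandomForestRegressor` independently and averages the shrunk predictions, mirroring how the forest averages its trees. It uses `weighted_n_node_samples` as the node size so that bootstrap duplicates are counted; whether imodels makes the same choice is an assumption we have not verified. `hs_node_values`, `hs_forest_predict`, and the synthetic data are hypothetical.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor


def hs_node_values(tree, reg_param):
    """Hypothetical helper (not imodels): HS-shrunk value of every node of a fitted regression tree."""
    t = tree.tree_
    raw = t.value[:, 0, 0]             # (weighted) mean response in each node
    n = t.weighted_n_node_samples      # bootstrap draws in each node (each duplicate counted)
    shrunk = np.empty_like(raw)
    shrunk[0] = raw[0]                 # the root is never shrunk
    stack = [0]
    while stack:                       # parents are always processed before their children
        parent = stack.pop()
        for child in (t.children_left[parent], t.children_right[parent]):
            if child != -1:            # -1 marks a missing child (parent is a leaf)
                step = (raw[child] - raw[parent]) / (1.0 + reg_param / n[parent])
                shrunk[child] = shrunk[parent] + step
                stack.append(child)
    return shrunk


def hs_forest_predict(forest, X, reg_param):
    """Hypothetical helper: shrink each tree separately, then average like the forest does."""
    per_tree = [hs_node_values(est, reg_param)[est.apply(X)] for est in forest.estimators_]
    return np.mean(per_tree, axis=0)


# synthetic regression data (illustrative only)
rng = np.random.default_rng(0)
X = rng.uniform(size=(300, 4))
y = np.sin(3 * X[:, 0]) + rng.normal(scale=0.3, size=300)

forest = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)
preds = hs_forest_predict(forest, X, reg_param=10.0)
```

The forest structure is untouched; only the node values used at prediction time change, which is why the same post-hoc step applies to any ensemble of fitted trees.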