Hierarchical shrinkage: improving the accuracy and interpretability of tree-based methods
Abhineet Agarwal*, Yan Shuo Tan*, Omer Ronen, Chandan Singh, Bin Yu
How does hierarchical shrinkage work?
Fig 1. HS applies post-hoc regularization to any decision tree by shrinking each node towards its parent. This is done after the tree has been trained. The amount of shrinkage is controlled by a regularization parameter, which works best when chosen via cross-validation.
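To make the idea of "shrinking each node towards its parent" concrete, here is a minimal sketch of how a single prediction could be computed. It assumes the node-based rule in which each parent-to-child change in mean response is divided by 1 + lam / N(parent), where N(parent) is the number of training samples in the parent node. The function name `hs_prediction` and the example numbers are hypothetical and are not part of imodels.

```python
def hs_prediction(path_means, path_counts, lam):
    """Hierarchically shrunken prediction for one root-to-leaf path.

    path_means[l]  -- mean training response in the l-th node on the path (root first)
    path_counts[l] -- number of training samples in that node
    lam            -- regularization parameter (lam = 0 leaves the tree unchanged)
    """
    prediction = path_means[0]  # the root mean is never shrunk
    for depth in range(1, len(path_means)):
        # Change in mean response when moving from the parent to the child
        step = path_means[depth] - path_means[depth - 1]
        # Shrink that change more strongly when the parent holds few samples
        prediction += step / (1.0 + lam / path_counts[depth - 1])
    return prediction


# Hypothetical path: root -> internal node -> leaf (illustrative numbers only)
means = [2.0, 4.0, 8.0]
counts = [100, 50, 10]

unshrunk = hs_prediction(means, counts, lam=0.0)   # equals the leaf mean
shrunk = hs_prediction(means, counts, lam=100.0)   # pulled towards the root mean
```

With `lam=0` the steps telescope back to the original leaf mean, while larger values of `lam` move the prediction towards the root mean, with the strongest shrinkage applied below nodes that contain few samples.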
An example using HS
HS can be used in the same way as standard scikit-learn models: simply import a classifier or regressor and use the `fit` and `predict` methods. Here's a full example of using it on a sample clinical dataset.

```python
from imodels import HSTreeClassifierCV, get_clean_dataset
from sklearn.model_selection import train_test_split
from sklearn.tree import plot_tree

# prepare data (in this case, a sample clinical dataset)
X, y, feat_names = get_clean_dataset('csi_pecarn_pred')
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

# fit the model
model = HSTreeClassifierCV(max_leaf_nodes=7)  # initialize a model
model.fit(X_train, y_train)  # fit model
preds = model.predict(X_test)  # discrete predictions: shape is (n_test,)
preds_proba = model.predict_proba(X_test)  # predicted probabilities: shape is (n_test, n_classes)

# visualize the model
plot_tree(model.estimator_, feature_names=feat_names)
```

Here we used `HSTreeClassifierCV`, which selects the amount of regularization to use via cross-validation, but we can also use `HSTreeClassifier` if we want to specify a particular amount of regularization. For regression, we can use the corresponding regressor classes.
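As a rough illustration of what selecting the amount of regularization via cross-validation involves, the sketch below scores a small grid of values on a noisy one-dimensional step function. It is a simplified stand-in and not the imodels implementation: the tree is a depth-1 stump with a fixed split at 0.5, and the helper names, the grid, and the synthetic data are all assumptions made for this example.

```python
import random


def fit_stump(xs, ys, threshold):
    """Fit a depth-1 tree: return root mean, both leaf means, and root sample count."""
    left = [y for x, y in zip(xs, ys) if x <= threshold]
    right = [y for x, y in zip(xs, ys) if x > threshold]
    root_mean = sum(ys) / len(ys)
    left_mean = sum(left) / len(left) if left else root_mean
    right_mean = sum(right) / len(right) if right else root_mean
    return root_mean, left_mean, right_mean, len(ys)


def predict_stump(stump, x, threshold, lam):
    """Predict with the leaf mean shrunk towards the root mean."""
    root_mean, left_mean, right_mean, n_root = stump
    leaf_mean = left_mean if x <= threshold else right_mean
    return root_mean + (leaf_mean - root_mean) / (1.0 + lam / n_root)


def cv_error(xs, ys, threshold, lam, n_folds=5):
    """Mean squared error of the shrunken stump under k-fold cross-validation."""
    total, count = 0.0, 0
    for fold in range(n_folds):
        train = [i for i in range(len(xs)) if i % n_folds != fold]
        test = [i for i in range(len(xs)) if i % n_folds == fold]
        stump = fit_stump([xs[i] for i in train], [ys[i] for i in train], threshold)
        for i in test:
            total += (predict_stump(stump, xs[i], threshold, lam) - ys[i]) ** 2
            count += 1
    return total / count


# Synthetic data (hypothetical): a step at x = 0.5 plus Gaussian noise
rng = random.Random(0)
xs = [rng.random() for _ in range(60)]
ys = [(1.0 if x > 0.5 else 0.0) + rng.gauss(0.0, 1.0) for x in xs]

# Score each candidate value and keep the one with the lowest CV error
lam_grid = [0.0, 0.1, 1.0, 10.0, 25.0, 50.0, 100.0]
errors = {lam: cv_error(xs, ys, threshold=0.5, lam=lam) for lam in lam_grid}
best_lam = min(errors, key=errors.get)
```

The value in `best_lam` depends on the noise level and sample size, which is why the parameter is better chosen from data than fixed in advance.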
Examples with HS on synthetic data
The following examples show how hierarchical shrinkage behaves on one-dimensional functions fitted with a CART decision tree.
Fig 3. Step function.
Fig 4. Linear function.
Applying HS to tree ensembles
HS can also be used on tree ensembles to regularize each tree in an ensemble (e.g. in a Random Forest). We simply pass the desired estimator during initialization.
```python
from sklearn.ensemble import RandomForestClassifier  # also works with ExtraTreesClassifier, GradientBoostingClassifier
from imodels import HSTreeClassifier

ensemble = RandomForestClassifier()
model = HSTreeClassifier(estimator_=ensemble)
model = model.fit(X_train, y_train)
```