MDI+: A Flexible Feature Importance Framework for Random Forests
Abhineet Agarwal*, Ana M. Kenney*, Yan Shuo Tan*, Tiffany Tang*, Bin Yu
MDI+ is a novel feature importance framework that generalizes the popular mean decrease in impurity (MDI) importance score for random forests. At its core, MDI+ expands upon a recently discovered connection between linear regression and decision trees. In doing so, MDI+ enables practitioners to (1) tailor the feature importance computation to the data/problem structure and (2) incorporate additional features or knowledge to mitigate known biases of decision trees. In both real-data case studies and extensive real-data-inspired simulations, MDI+ outperforms commonly used feature importance measures (e.g., MDI, permutation-based scores, and TreeSHAP) by substantial margins.
For further details, we refer to Agarwal et al. (2023).
Regression Example Usage:

```python
from imodels.importance import RandomForestPlusRegressor

# X: feature matrix of shape (n_samples, n_features); y: continuous response
rf_plus_model = RandomForestPlusRegressor()
rf_plus_model.fit(X, y)
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X, y)
```
Classification Example Usage:

```python
from imodels.importance import RandomForestPlusClassifier

# X: feature matrix of shape (n_samples, n_features); y: class labels
rf_plus_model = RandomForestPlusClassifier()
rf_plus_model.fit(X, y)
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X, y)
```
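In both cases, the returned MDI+ scores can be ranked to surface the most important features. A minimal sketch, assuming the return value behaves like a pandas DataFrame with one row per feature and a column of importance values (the actual column names may differ across imodels versions):

```python
# Rank features by their MDI+ importance scores.
# Assumes `mdi_plus_scores` is a pandas DataFrame with an "importance" column;
# check the column names returned by your imodels version.
ranked = mdi_plus_scores.sort_values("importance", ascending=False)
print(ranked.head(10))  # the ten highest-scoring features
```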
Demo notebooks
MDI+ demo
- Shows how to compute MDI+ importance scores for different tasks (regression and classification) and configurations (with flexible GLMs, scoring metrics, and custom transformations).
- Provides starter code for choosing the GLM and scoring metric within MDI+ via a stability metric and/or combining multiple fits in an ensemble.
Overview of MDI+
Input: a collection of fitted trees (e.g., a random forest) and data X and y.

For each fitted tree in the forest:

1. Transform X using the learnt "stump" features from the fitted tree, and append any additional engineered features (e.g., the raw X features) to this transformed dataset.
2. Fit a prediction model on this augmented, transformed dataset to predict y. Here, we recommend fitting a generalized linear model (GLM) to leverage computational speed-ups.
3. Use this fitted prediction model to make partial model predictions for each feature k. That is, for each feature k, obtain the model's predictions when the contributions of all other features (except feature k) are zeroed out. Put differently, the kth partial model predictions are the predictions obtained using only the engineered features derived from feature k.
4. For each feature k, evaluate the similarity between the observed y and the kth partial model predictions using any user-defined similarity metric (a larger value should indicate greater feature importance).

This gives the MDI+ scores for a single tree. To get the MDI+ scores for the forest, these scores are averaged across all trees in the forest.
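To make the steps above concrete, here is a simplified, self-contained sketch of the per-tree computation for a regression task using scikit-learn. It is illustrative only, not the imodels implementation: the actual MDI+ uses appropriately centered and scaled "stump" features and leave-one-out partial predictions (see Practical Considerations below), whereas this sketch uses plain split-indicator features and in-sample predictions.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import RidgeCV
from sklearn.metrics import r2_score

def stump_features(tree, X):
    """Indicator features 1{x_j <= threshold}, one per internal node of the tree,
    along with the raw feature index that each column is derived from."""
    t = tree.tree_
    internal = np.where(t.children_left != -1)[0]   # internal (splitting) nodes
    feats = t.feature[internal]                     # raw feature split on at each node
    thresh = t.threshold[internal]
    Psi = (X[:, feats] <= thresh).astype(float)     # one indicator column per split
    return Psi, feats

def mdi_plus_single_tree(tree, X, y):
    # Step 1: stump features from this tree, plus the raw X features
    Psi, col_feats = stump_features(tree, X)
    Z = np.hstack([Psi, X])
    col_feats = np.concatenate([col_feats, np.arange(X.shape[1])])
    Z = Z - Z.mean(axis=0)                          # center so zeroing a block is meaningful

    # Step 2: fit a GLM (ridge regression) on the engineered features
    glm = RidgeCV().fit(Z, y)

    # Steps 3-4: partial model predictions per feature k, scored against y
    scores = np.zeros(X.shape[1])
    for k in range(X.shape[1]):
        Z_k = np.zeros_like(Z)
        Z_k[:, col_feats == k] = Z[:, col_feats == k]   # keep only feature k's block
        scores[k] = r2_score(y, glm.predict(Z_k))       # R-squared as the similarity metric
    return scores

# Forest-level scores: average the per-tree scores across all trees.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X[:, 0] + 0.5 * X[:, 1] + rng.normal(scale=0.1, size=200)
rf = RandomForestRegressor(n_estimators=25, random_state=0).fit(X, y)
mdi_plus = np.mean([mdi_plus_single_tree(est, X, y) for est in rf.estimators_], axis=0)
print(mdi_plus)
```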
Practical Considerations
We show in Agarwal et al. (2023) that this framework is indeed a proper generalization of the popular MDI feature importance score. However, the increased flexibility of MDI+ means the analyst must make several choices to run it in practice. In particular:
- In Step 1: What feature engineering/transformations to include?
  - We recommend including the raw features (i.e., X) in this transformed dataset. This is done by default via `RandomForestPlus*(include_raw=True)`. To include additional transformations, create custom `BlockTransformerBase` object(s) and pass them via the `add_transformers` argument in `RandomForestPlus*()`.
- In Step 2: Which GLM?
  - We recommend `RidgeRegressorPPM()` for regression tasks and `LogisticClassifierPPM()` for classification tasks, and thus set these as the defaults. To use a custom prediction model, use the `prediction_model` argument in `RandomForestPlus*()`.
- In Step 3: Which sample-splitting strategy (if any) to use when making the partial model predictions?
  - We recommend a leave-one-out (`"loo"`) sample-splitting strategy, as it overcomes the known correlation and entropy biases suffered by MDI. Out-of-bag (`"oob"`) splitting can also be used to overcome these biases but tends to be less stable than leave-one-out across different random forest fits. MDI uses an in-bag sample-splitting scheme and is not recommended. The sample-splitting strategy is set to `"loo"` by default but can be changed via the `sample_split` argument in `RandomForestPlus*()`.
- In Step 4: Which similarity metric to use?
  - We recommend R-squared for regression tasks and log-loss for classification tasks, which are the defaults. To use a custom metric, use the `scoring_fns` argument in the `get_mdi_plus_scores()` method.
These recommendations are based on extensive simulations across a wide variety of data-generating processes, data sets, noise levels, and misspecifications.
Nevertheless, different choices may be better for different problems. For examples of how to implement some of these custom options, see the MDI+ Demo notebook. The demo also includes examples of how to aggregate feature importances from multiple MDI+ configurations in an ensemble, as well as how to choose the "best" GLM and metric in a data-driven manner based upon a stability score.
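As a rough illustration of how the choices above map onto code, the sketch below makes the recommended defaults explicit for a regression task. It is a hedged sketch: the argument names (`include_raw`, `add_transformers`, `prediction_model`, `sample_split`, `scoring_fns`) follow the recommendations above, but the exact import path for `RidgeRegressorPPM` and the precise signatures may differ across imodels versions.

```python
from imodels.importance import RandomForestPlusRegressor
# Hedged: the module path for the partial prediction models may differ by version.
from imodels.importance.ppms import RidgeRegressorPPM

rf_plus_model = RandomForestPlusRegressor(
    include_raw=True,                      # Step 1: append the raw X features (default)
    # add_transformers=[...],              # Step 1: optional custom BlockTransformerBase objects
    prediction_model=RidgeRegressorPPM(),  # Step 2: GLM fit on the transformed data (default)
    sample_split="loo",                    # Step 3: leave-one-out partial predictions (default)
)
rf_plus_model.fit(X, y)

# Step 4: the similarity metric is chosen when computing the scores.
# The default (R-squared for regression) is used here; a custom metric can be
# supplied via the scoring_fns argument of get_mdi_plus_scores().
mdi_plus_scores = rf_plus_model.get_mdi_plus_scores(X, y)
```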