Expand source code
import argparse
import itertools
import os
from functools import partial

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import cm
from tqdm import tqdm
from sklearn import model_selection, datasets
from sklearn.metrics import mean_squared_error

from imodels import get_clean_dataset
from imodels.experimental.bartpy.model import Model
from imodels.experimental.bartpy.tree import Tree
from ..sklearnmodel import BART, SklearnModel

# Root directory where generated figures are written (cluster-specific absolute path).
ART_PATH = "/accounts/campus/omer_ronen/projects/tree_shrink/imodels/art"

# Regression benchmarks as (display_name, dataset_id, source) triples;
# `source` tells imodels.get_clean_dataset where to fetch the data from.
DATASETS_REGRESSION = [
    # leo-breiman paper random forest uses some UCI datasets as well
    # pg 23: https://www.stat.berkeley.edu/~breiman/randomforest2001.pdf
    ('friedman1', 'friedman1', 'synthetic'),
    ('friedman2', 'friedman2', 'synthetic'),
    ('friedman3', 'friedman3', 'synthetic'),
    ('abalone', '183', 'openml'),
    ("diabetes-regr", "diabetes", 'sklearn'),
    ("california-housing", "california_housing", 'sklearn'),  # this replaced boston-housing due to ethical issues
    ("satellite-image", "294_satellite_image", 'pmlb'),
    ("echo-months", "1199_BNG_echoMonths", 'pmlb'),
    ("breast-tumor", "1201_BNG_breastTumor", 'pmlb'),  # this one is v big (100k examples)

]


def parse_args():
    """Parse command-line arguments.

    Returns an ``argparse.Namespace`` with a single attribute ``datasets``:
    the list of dataset names (one or more positional arguments).
    """
    parser = argparse.ArgumentParser(description='BART Research motivation')
    parser.add_argument('datasets', metavar='datasets', type=str, nargs='+',
                        help='datasets to run sim over')
    return parser.parse_args()


def mse_functional(model: SklearnModel, sample: Model, X, y):
    """Mean squared error of a single posterior sample on (X, y).

    The sample predicts on the model's normalized scale; predictions are
    mapped back to the original y scale via the model's data transform
    before the MSE is computed.
    """
    raw_preds = sample.predict(X)
    rescaled_preds = model.data.y.unnormalize_y(raw_preds)
    return mean_squared_error(rescaled_preds, y)


def n_leaves_functional(model: SklearnModel, sample: Model, X, y):
    """Average number of leaf nodes per tree in one posterior sample.

    ``model``, ``X`` and ``y`` are unused; they are accepted so this function
    shares the common functional signature expected by ``analyze_functional``.
    """
    leaf_counts = [len(tree.leaf_nodes) for tree in sample.trees]
    return sum(leaf_counts) / len(leaf_counts)


def analyze_functional(model: SklearnModel, functional: callable, ax=None, name=None, X=None, y=None):
    """Trace a per-sample functional for each MCMC chain on one axes.

    ``functional`` is called as ``functional(model=..., sample=..., X=..., y=...)``
    for every posterior sample; one line per chain is drawn on ``ax``
    (a new axes is created when ``ax`` is None). Returns the axes.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)
    chain_count = model.n_chains
    samples_per_chain = len(model.model_samples) // chain_count
    palette = cm.rainbow(np.linspace(0, 1, chain_count))

    evaluate = partial(functional, X=X, y=y, model=model)

    for idx, clr in zip(range(chain_count), palette):
        start = idx * samples_per_chain
        chain_samples = model.model_samples[start:start + samples_per_chain]
        trace = [evaluate(sample=s) for s in chain_samples]
        ax.plot(np.arange(samples_per_chain), trace, color=clr, label=f"Chain {idx}")

    ax.set_ylabel(name)
    ax.set_xlabel("Iteration")
    ax.legend()
    return ax


def plot_chains_leaves(model: SklearnModel, ax=None, title="Tree Structure/Prediction Variation", x_label=False, X=None,
                       y=None):
    """Plot each chain's mean leaf count per iteration; legend shows chain MSE.

    Returns the axes (a new one is created when ``ax`` is None).
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    # leaf_history[t][k] = number of leaves of tree t in posterior sample k
    leaf_history = {t: [] for t in range(model.n_trees)}
    for sample in model.model_samples:
        for t, tree in enumerate(sample.trees):
            leaf_history[t].append(len(tree.leaf_nodes))

    chain_count = model.n_chains
    samples_per_chain = len(model.model_samples) // chain_count
    palette = cm.rainbow(np.linspace(0, 1, chain_count))

    for c, clr in enumerate(palette):
        chain_preds = model.predict_chain(X, c)
        chain_std = np.round(model.chain_mse_std(X, y, c), 2)
        mse_chain = np.round(mean_squared_error(chain_preds, y), 2)

        start, stop = c * samples_per_chain, (c + 1) * samples_per_chain
        per_tree = np.stack([leaf_history[t][start:stop] for t in range(model.n_trees)], axis=1)
        ax.plot(np.arange(samples_per_chain), per_tree.mean(axis=1), color=clr,
                label=f"Chain {c} (mse: {mse_chain} std: {chain_std})")

    ax.set_ylabel("# Leaves")
    if x_label:
        ax.set_xlabel("Iteration")
    ax.legend()
    ax.set_title(title)
    return ax


def plot_within_chain(model: SklearnModel, ax=None, title="Within Chain Variation", x_label=False, X=None, y=None):
    """Plot, per chain, each iteration's squared distance to the chain-mean prediction.

    Returns the axes (a new one is created when ``ax`` is None).
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    chain_count = model.n_chains
    samples_per_chain = len(model.model_samples) // chain_count
    palette = cm.rainbow(np.linspace(0, 1, chain_count))

    for c, clr in enumerate(palette):
        per_iter_preds = model.chain_predictions(X, c)
        centroid = np.array(per_iter_preds).mean(axis=0)

        distances = [mean_squared_error(centroid, p) for p in per_iter_preds]
        ax.plot(np.arange(samples_per_chain), distances, color=clr,
                label=f"Chain {c} (Average {np.round(np.mean(distances), 2)})")

    ax.set_ylabel("mean squared distance to average iteration")
    if x_label:
        ax.set_xlabel("Iteration")
    ax.legend()
    ax.set_title(title)
    return ax


def plot_across_chains(model: SklearnModel, ax=None, title="Across Chain Variation", x_label=False, X=None, y=None):
    """Plot a matrix of pairwise MSE distances between chain predictions.

    Each cell (i, j) is the mean squared distance between chain i's and
    chain j's averaged predictions on ``X``; the matrix is symmetric with a
    zero diagonal. ``x_label`` and ``y`` are unused (kept for signature
    consistency with the other plot helpers). Returns the axes.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)
    n_chains = model.n_chains

    preds = [model.predict_chain(X, c) for c in range(n_chains)]
    mat = np.zeros(shape=(n_chains, n_chains))

    for c_i, c_j in itertools.combinations(range(n_chains), 2):
        # MSE is symmetric: compute each pair once and fill both cells
        # (the original recomputed it for the mirrored cell).
        pair_mse = mean_squared_error(preds[c_i], preds[c_j])
        mat[c_i, c_j] = pair_mse
        mat[c_j, c_i] = pair_mse
    ax.matshow(mat)
    for c_i, c_j in itertools.combinations(range(n_chains), 2):
        cell = np.round(mat[c_i, c_j], 2)
        ax.text(c_i, c_j, cell, va='center', ha='center')
        ax.text(c_j, c_i, cell, va='center', ha='center')
    ax.set_xlabel("mean squared distance between predictions")

    # No ax.legend() here: nothing on this axes carries a label, so the
    # original call only emitted a "no handles" warning and drew nothing.
    ax.set_title(f"{title} (Between Chains Var {np.round(model.between_chains_var(X), 2)})")
    return ax


def main():
    """Fit multi-chain BART on each regression benchmark and save diagnostics.

    For every dataset in DATASETS_REGRESSION: fetch and split the data
    (70/30), fit a 5-chain BART model, then render four stacked panels
    (per-chain test MSE, average leaf counts, within-chain variation,
    across-chain variation) and save the figure under ART_PATH/functional.
    """
    n_trees = 50
    n_samples = 5000
    n_burn = 10000
    n_chains = 5

    out_dir = os.path.join(ART_PATH, "functional")
    # Robustness: savefig fails with FileNotFoundError if the directory is
    # missing, so create it up front (no-op when it already exists).
    os.makedirs(out_dir, exist_ok=True)

    with tqdm(DATASETS_REGRESSION) as t:
        for d in t:
            t.set_description(f'{d[0]}')
            # Feature names are not used by any of the diagnostics below.
            X, y, _feat_names = get_clean_dataset(d[1], data_source=d[2])
            n = len(y)

            X_train, X_test, y_train, y_test = model_selection.train_test_split(
                X, y, test_size=0.3, random_state=4)

            bart_zero = BART(classification=False, store_acceptance_trace=True, n_trees=n_trees, n_samples=n_samples,
                             n_burn=n_burn, n_chains=n_chains, thin=1)
            bart_zero.fit(X_train, y_train)

            fig, axs = plt.subplots(4, 1, figsize=(10, 22))
            fig.subplots_adjust(hspace=.2)

            analyze_functional(bart_zero, functional=mse_functional, ax=axs[0], X=X_test, y=y_test, name="Test MSE")
            analyze_functional(bart_zero, functional=n_leaves_functional, ax=axs[1], X=X_test, y=y_test,
                               name="# Leaves")
            plot_within_chain(bart_zero, axs[2], X=X_test, y=y_test)
            plot_across_chains(bart_zero, axs[3], X=X_test, y=y_test)

            title = f"Dataset: {d[0].capitalize()}, (n = {n}, burn = {n_burn})"
            plt.suptitle(title)

            plt.savefig(os.path.join(out_dir, f"{d[0]}_samples_{n_samples}.png"))
            plt.close()


# Script entry point: run the full diagnostics pipeline when executed directly.
if __name__ == '__main__':
    main()

Functions

def analyze_functional(model: SklearnModel, functional: callable, ax=None, name=None, X=None, y=None)
Expand source code
def analyze_functional(model: SklearnModel, functional: callable, ax=None, name=None, X=None, y=None):
    if ax is None:
        _, ax = plt.subplots(1, 1)
    n_chains = model.n_chains
    chain_len = int(len(model.model_samples) / n_chains)
    color = iter(cm.rainbow(np.linspace(0, 1, n_chains)))

    functional_specific = partial(functional, X=X, y=y, model=model)

    for c in range(n_chains):
        clr = next(color)
        chain_sample = model.model_samples[c * chain_len:(c + 1) * chain_len]
        chain_functional = [functional_specific(sample=s) for s in chain_sample]
        ax.plot(np.arange(chain_len), chain_functional, color=clr, label=f"Chain {c}")

    ax.set_ylabel(name)
    ax.set_xlabel("Iteration")
    ax.legend()
    # ax.set_title(title)
    return ax
def main()
Expand source code
def main():
    n_trees = 50
    n_samples = 5000
    n_burn = 10000
    n_chains = 5
    with tqdm(DATASETS_REGRESSION) as t:
        for d in t:
            t.set_description(f'{d[0]}')
            X, y, feat_names = get_clean_dataset(d[1], data_source=d[2])
            n = len(y)

            X_train, X_test, y_train, y_test = model_selection.train_test_split(
                X, y, test_size=0.3, random_state=4)

            bart_zero = BART(classification=False, store_acceptance_trace=True, n_trees=n_trees, n_samples=n_samples,
                             n_burn=n_burn, n_chains=n_chains, thin=1)
            bart_zero.fit(X_train, y_train)

            fig, axs = plt.subplots(4, 1, figsize=(10, 22))
            # fig.tight_layout()
            fig.subplots_adjust(hspace=.2)

            # plot_chains_leaves(bart_zero, axs[0], X=X_test, y=y_test)
            analyze_functional(bart_zero, functional=mse_functional, ax=axs[0], X=X_test, y=y_test, name="Test MSE")
            analyze_functional(bart_zero, functional=n_leaves_functional, ax=axs[1], X=X_test, y=y_test,
                               name="# Leaves")
            plot_within_chain(bart_zero, axs[2], X=X_test, y=y_test)
            plot_across_chains(bart_zero, axs[3], X=X_test, y=y_test)

            title = f"Dataset: {d[0].capitalize()}, (n = {n}, burn = {n_burn})"
            plt.suptitle(title)

            plt.savefig(os.path.join(ART_PATH, "functional", f"{d[0]}_samples_{n_samples}.png"))
            plt.close()
def mse_functional(model: SklearnModel, sample: Model, X, y)
Expand source code
def mse_functional(model: SklearnModel, sample: Model, X, y):
    predictions_transformed = sample.predict(X)
    predictions = model.data.y.unnormalize_y(predictions_transformed)
    return mean_squared_error(predictions, y)
def n_leaves_functional(model: SklearnModel, sample: Model, X, y)
Expand source code
def n_leaves_functional(model: SklearnModel, sample: Model, X, y):
    n_leaves = 0
    for tree in sample.trees:
        n_leaves += len(tree.leaf_nodes)
    return n_leaves / len(sample.trees)
def parse_args()
Expand source code
def parse_args():
    parser = argparse.ArgumentParser(description='BART Research motivation')
    parser.add_argument('datasets', metavar='datasets', type=str, nargs='+',
                        help='datasets to run sim over')

    args = parser.parse_args()
    return args
def plot_across_chains(model: SklearnModel, ax=None, title='Across Chain Variation', x_label=False, X=None, y=None)
Expand source code
def plot_across_chains(model: SklearnModel, ax=None, title="Across Chain Variation", x_label=False, X=None, y=None):
    if ax is None:
        _, ax = plt.subplots(1, 1)
    n_chains = model.n_chains

    preds = []
    mat = np.zeros(shape=(n_chains, n_chains))

    for c in range(n_chains):
        preds.append(model.predict_chain(X, c))
    for c_i, c_j in itertools.combinations(range(n_chains), 2):
        mat[c_i, c_j] = mean_squared_error(preds[c_i], preds[c_j])
        mat[c_j, c_i] = mean_squared_error(preds[c_i], preds[c_j])
    ax.matshow(mat)
    for c_i, c_j in itertools.combinations(range(n_chains), 2):
        c = np.round(mat[c_i, c_j], 2)
        ax.text(c_i, c_j, c, va='center', ha='center')
        ax.text(c_j, c_i, c, va='center', ha='center')
    ax.set_xlabel("mean squared distance between predictions")

    ax.legend()
    ax.set_title(f"{title} (Between Chains Var {np.round(model.between_chains_var(X), 2)})")
    return ax
def plot_chains_leaves(model: SklearnModel, ax=None, title='Tree Structure/Prediction Variation', x_label=False, X=None, y=None)
Expand source code
def plot_chains_leaves(model: SklearnModel, ax=None, title="Tree Structure/Prediction Variation", x_label=False, X=None,
                       y=None):
    if ax is None:
        _, ax = plt.subplots(1, 1)
    complexity = {i: [] for i in range(model.n_trees)}
    n_chains = model.n_chains
    for sample in model.model_samples:
        for i, tree in enumerate(sample.trees):
            complexity[i].append(len(tree.leaf_nodes))

    chain_len = int(len(model.model_samples) / n_chains)
    color = iter(cm.rainbow(np.linspace(0, 1, n_chains)))

    for c in range(n_chains):
        clr = next(color)
        chain_preds = model.predict_chain(X, c)
        chain_std = np.round(model.chain_mse_std(X, y, c), 2)
        mse_chain = np.round(mean_squared_error(chain_preds, y), 2)

        trees_chain = np.stack([complexity[t][c * chain_len:(c + 1) * chain_len] for t in range(model.n_trees)], axis=1)
        y_plt = np.mean(trees_chain, axis=1)
        ax.plot(np.arange(chain_len), y_plt, color=clr, label=f"Chain {c} (mse: {mse_chain} std: {chain_std})")

    ax.set_ylabel("# Leaves")
    if x_label:
        ax.set_xlabel("Iteration")
    ax.legend()
    ax.set_title(title)
    return ax
def plot_within_chain(model: SklearnModel, ax=None, title='Within Chain Variation', x_label=False, X=None, y=None)
Expand source code
def plot_within_chain(model: SklearnModel, ax=None, title="Within Chain Variation", x_label=False, X=None, y=None):
    if ax is None:
        _, ax = plt.subplots(1, 1)
    n_chains = model.n_chains

    chain_len = int(len(model.model_samples) / n_chains)
    color = iter(cm.rainbow(np.linspace(0, 1, n_chains)))

    for c in range(n_chains):
        clr = next(color)
        chain_preds = model.chain_predictions(X, c)
        mean_pred = np.array(chain_preds).mean(axis=0)

        y_plt = [mean_squared_error(mean_pred, p) for p in chain_preds]
        ax.plot(np.arange(chain_len), y_plt, color=clr, label=f"Chain {c} (Average {np.round(np.mean(y_plt), 2)})")

    ax.set_ylabel("mean squared distance to average iteration")
    if x_label:
        ax.set_xlabel("Iteration")
    ax.legend()
    ax.set_title(title)
    return ax