Expand source code
import numpy as np
from matplotlib import pyplot as plt
from ..sklearnmodel import SklearnModel
def plot_tree_depth(model: SklearnModel, ax=None, title="", x_label=False):
    """Plot the number of leaf nodes of every tree across the stored model samples.

    Despite the function name, the quantity plotted is tree size (leaf count),
    not depth.  One line is drawn per tree index.  The x-axis is the position
    of the sample in ``model.model_samples``.

    Parameters
    ----------
    model : SklearnModel
        Model whose ``model_samples`` are iterated.  Each sample's ``trees``
        are matched to ``model.n_trees`` indices.  Every sample is expected to
        hold exactly ``model.n_trees`` trees:
        - extra trees raise ``KeyError``;
        - missing trees make the series lengths disagree when plotting.
    ax : matplotlib Axes, optional
        Axes to draw on.  If None, a new figure with a single Axes is created.
    title : str
        Title applied to the axes.
    x_label : bool
        If True, label the x-axis "Iteration".

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    samples = model.model_samples

    # One list of leaf counts per tree index, appended in sample order.
    complexity = {i: [] for i in range(model.n_trees)}
    for sample in samples:
        for i, tree in enumerate(sample.trees):
            complexity[i].append(len(tree.leaf_nodes))

    # The x positions are identical for every tree, so compute them once.
    iterations = np.arange(len(samples))
    for tree_number, comp in complexity.items():
        ax.plot(iterations, comp, label=f"tree {tree_number}", alpha=0.5)

    ax.set_ylabel("# Leaves")
    if x_label:
        ax.set_xlabel("Iteration")
    ax.legend()
    ax.set_title(title)
    return ax
Functions
def plot_tree_depth(model: SklearnModel, ax=None, title='', x_label=False)
Plot the number of leaf nodes of each tree (one line per tree) across the model's stored samples.
Expand source code
def plot_tree_depth(model: SklearnModel, ax=None, title="", x_label=False):
    """Plot the number of leaf nodes of every tree across the stored model samples.

    Despite the function name, the quantity plotted is tree size (leaf count),
    not depth.  One line is drawn per tree index.  The x-axis is the position
    of the sample in ``model.model_samples``.

    Parameters
    ----------
    model : SklearnModel
        Model whose ``model_samples`` are iterated.  Each sample's ``trees``
        are matched to ``model.n_trees`` indices.  Every sample is expected to
        hold exactly ``model.n_trees`` trees:
        - extra trees raise ``KeyError``;
        - missing trees make the series lengths disagree when plotting.
    ax : matplotlib Axes, optional
        Axes to draw on.  If None, a new figure with a single Axes is created.
    title : str
        Title applied to the axes.
    x_label : bool
        If True, label the x-axis "Iteration".

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    samples = model.model_samples

    # One list of leaf counts per tree index, appended in sample order.
    complexity = {i: [] for i in range(model.n_trees)}
    for sample in samples:
        for i, tree in enumerate(sample.trees):
            complexity[i].append(len(tree.leaf_nodes))

    # The x positions are identical for every tree, so compute them once.
    iterations = np.arange(len(samples))
    for tree_number, comp in complexity.items():
        ax.plot(iterations, comp, label=f"tree {tree_number}", alpha=0.5)

    ax.set_ylabel("# Leaves")
    if x_label:
        ax.set_xlabel("Iteration")
    ax.legend()
    ax.set_title(title)
    return ax