Module imodelsx.viz
Functions
def extract_sklearn_tree_from_llm_tree(llm_tree, n_classes, with_leaf_predictions=False, dtreeviz_dummies=False)
-
Expand source code
def extract_sklearn_tree_from_llm_tree(
    llm_tree,
    n_classes,
    with_leaf_predictions=False,
    dtreeviz_dummies=False,
):
    """Convert an LLM tree model into a sklearn ``DecisionTreeClassifier``.

    The node arrays produced by ``_extract_arrays_from_llm_tree`` are packed
    into the state dict that ``sklearn.tree._tree.Tree.__setstate__`` accepts,
    and that low-level tree is then injected into a ``DecisionTreeClassifier``
    through its own ``__setstate__``.

    Parameters
    ----------
    llm_tree
        Tree model to convert. It must expose a ``root_`` node whose children
        are reachable through ``child_left`` / ``child_right``.
    n_classes : int
        Number of classes. Used for the sklearn ``n_classes`` array and for
        ``classes_`` (``np.arange(n_classes)``). When it equals 1, only the
        first column of the per-node values is kept.
    with_leaf_predictions : bool, default False
        Selects the second element of the returned tuple (see Returns).
    dtreeviz_dummies : bool, default False
        Forwarded unchanged to ``_extract_arrays_from_llm_tree``.

    Returns
    -------
    tuple
        ``(dt, strs_array)`` when ``with_leaf_predictions`` is False, else
        ``(dt, leaf_values_dict)`` where ``leaf_values_dict`` maps each leaf's
        ``node_id`` to ``node.value[0][0]``.

    Raises
    ------
    Exception
        If ``DecisionTreeClassifier.__setstate__`` rejects the assembled
        state; the underlying error is chained as ``__cause__``.
    """
    tree_data_namedtuple, value_sklearn_array, strs_array = \
        _extract_arrays_from_llm_tree(
            llm_tree, dtreeviz_dummies=dtreeviz_dummies)

    # manipulate tree_data_namedtuple into the numpy array of tuples
    # that sklearn expects for use with __setstate__()
    df_tree_data = pd.DataFrame(tree_data_namedtuple._asdict())
    tree_data_list_of_tuples = list(
        df_tree_data.itertuples(index=False, name=None))
    _dtypes = np.dtype([('left_child', 'i8'),
                        ('right_child', 'i8'),
                        ('feature', 'i8'),
                        ('threshold', 'f8'),
                        ('impurity', 'f8'),
                        ('n_node_samples', 'i8'),
                        ('weighted_n_node_samples', 'f8')])
    tree_data_array = np.array(tree_data_list_of_tuples, dtype=_dtypes)

    # insert the n_outputs axis: (n_nodes, n_cols) -> (n_nodes, 1, n_cols)
    value_sklearns = value_sklearn_array.reshape(
        value_sklearn_array.shape[0], 1, value_sklearn_array.shape[1])
    if n_classes == 1:
        # keep only the first value column; the slice is a view, so copy it
        # into a contiguous buffer
        value_sklearns = np.ascontiguousarray(value_sklearns[:, :, 0:1])

    def get_max_depth(node):
        # depth counted in edges: a missing node is -1, a lone root is 0
        if node is None:
            return -1
        return 1 + max(get_max_depth(node.child_left),
                       get_max_depth(node.child_right))

    max_depth = get_max_depth(llm_tree.root_)

    # get other variables needed for the sklearn.tree._tree.Tree constructor
    # and __setstate__() calls
    node_count = len(tree_data_array)
    features = np.array(tree_data_namedtuple.feature)
    # only non-negative feature indices are counted as real features
    n_features = np.unique(features[np.where(0 <= features)]).size
    n_classes_array = np.array([n_classes], dtype=int)
    n_outputs = 1

    # make dict to pass to __setstate__()
    _state = {
        'max_depth': max_depth,
        'node_count': node_count,
        'nodes': tree_data_array,
        'values': value_sklearns,
        # WARNING this circumvents
        # UserWarning: Trying to unpickle estimator DecisionTreeClassifier
        # from version pre-0.18 when using version
        # https://github.com/scikit-learn/scikit-learn/blob/53acd0fe52cb5d8c6f5a86a1fc1352809240b68d/sklearn/base.py#L279
        '_sklearn_version': __version__,
    }

    tree = Tree(n_features=n_features,
                n_classes=n_classes_array, n_outputs=n_outputs)
    # https://github.com/scikit-learn/scikit-learn/blob/3850935ea610b5231720fdf865c837aeff79ab1b/sklearn/tree/_tree.pyx#L677
    tree.__setstate__(_state)

    # add the tree_ for the dt __setstate__()
    # note the trailing underscore also trips the sklearn_is_fitted protections
    _state['tree_'] = tree
    _state['classes_'] = np.arange(n_classes)
    _state['n_outputs_'] = n_outputs

    # construct sklearn object and __setstate__()
    # (only the classifier is built here; no regressor variant is produced)
    dt = DecisionTreeClassifier(max_depth=max_depth)
    try:
        dt.__setstate__(_state)
    except Exception as err:
        # chain the original error so the real cause is not hidden
        raise Exception(
            f'Did not successfully run __setstate__() when translating to {type(dt)}, did sklearn update?') from err

    if not with_leaf_predictions:
        return dt, strs_array

    leaf_values_dict = {}

    def _read_node(node):
        # Walk the same structure get_max_depth uses: start at root_ and
        # follow child_left / child_right.
        # NOTE(review): assumes every node exposes `node_id` and a `value`
        # indexable as [0][0] -- confirm against the node class.
        if node is None:
            return
        if node.child_left is None and node.child_right is None:
            leaf_values_dict[node.node_id] = node.value[0][0]
        _read_node(node.child_left)
        _read_node(node.child_right)

    _read_node(llm_tree.root_)
    return dt, leaf_values_dict
Takes in a Tree model and converts it to a sklearn decision tree.