Module imodelsx.augtree.augtree
from typing import Dict, List
import numpy as np
import imodels
import imodelsx.augtree.data
import imodelsx.augtree.llm
import imodelsx.augtree.utils
from imodelsx.augtree.embed import EmbsManager
from imodelsx.augtree.stump import Stump, StumpClassifier, StumpRegressor
import logging
import warnings
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
class AugTree:
def __init__(
self,
max_depth: int = 3,
max_features=5,
split_strategy='cart',
refinement_strategy='None',
verbose=True,
tokenizer=None,
use_refine_ties=False,
assert_checks=False,
llm_prompt_context: str='',
use_stemming=False,
embs_manager: EmbsManager=None,
cache_expansions_dir: str=None,
):
'''
Params
------
max_depth: int
Maximum depth of the tree.
max_features: int
Number of features to consider expanding at each stump
split_strategy: str
Strategy for generating candidate seed keyphrases.
refinement_strategy: str
'None', 'llm', or 'embs'
verbose: bool
Whether to print debug statements
tokenizer
Tokenizer to use for splitting text into tokens
        use_refine_ties: bool
            Whether to keep expanded keywords that neither improve nor hurt performance (ties)
assert_checks: bool
Whether to run checks during fitting
        llm_prompt_context: str
            Extra context string provided to llm_refine (if refinement_strategy='llm')
        embs_manager
            Class providing a function to query for keywords with the closest embeddings
cache_expansions_dir: str
Directory to cache keyphrase expansions
'''
self.max_depth = max_depth
self.max_features = max_features
self.split_strategy = split_strategy
self.verbose = verbose
self.use_refine_ties = use_refine_ties
self.assert_checks = assert_checks
self.llm_prompt_context = llm_prompt_context
self.refinement_strategy = refinement_strategy
self.use_stemming = use_stemming
self.embs_manager = embs_manager
self.cache_expansions_dir = cache_expansions_dir
if tokenizer is None:
self.tokenizer = imodelsx.augtree.utils.get_spacy_tokenizer(use_stemming=use_stemming)
else:
self.tokenizer = tokenizer
assert self.refinement_strategy in ['None', 'llm', 'embs']
if self.refinement_strategy == 'embs':
assert embs_manager is not None, 'must pass embs_manager when refinement_strategy == "embs"'
def fit(self, X=None, y=None, feature_names=None, X_text=None):
if X is None and X_text:
warnings.warn("X is not passed, defaulting to generating unigrams from X_text")
X, _, feature_names = imodelsx.augtree.data.convert_text_data_to_counts_array(X_text, [], ngrams=1)
# check and set some attributes
X, y, _ = imodels.util.arguments.check_fit_arguments(
self, X, y, feature_names)
if isinstance(X_text, list):
X_text = np.array(X_text).flatten()
self.feature_names = feature_names
if isinstance(self.feature_names, list):
self.feature_names = np.array(self.feature_names).flatten()
# fit root stump
stump_kwargs = dict(
split_strategy=self.split_strategy,
max_features=self.max_features,
tokenizer=self.tokenizer,
use_refine_ties=self.use_refine_ties,
assert_checks=self.assert_checks,
llm_prompt_context=self.llm_prompt_context,
refinement_strategy=self.refinement_strategy,
embs_manager = self.embs_manager,
verbose=self.verbose,
use_stemming=self.use_stemming,
cache_expansions_dir=self.cache_expansions_dir,
)
# assume that the initial split finds a feature that provides some benefit
# otherwise, one leaf will end up NaN
if isinstance(self, RegressorMixin):
stump_class = StumpRegressor
else:
stump_class = StumpClassifier
stump = stump_class(**stump_kwargs).fit(
X, y,
feature_names=self.feature_names,
X_text=X_text
)
stump.idxs = np.ones(X.shape[0], dtype=bool)
self.root_ = stump
# recursively fit stumps and store as a decision tree
stumps_queue = [stump]
i = 0
depth = 1
while depth < self.max_depth:
stumps_queue_new = []
for stump in stumps_queue:
if self.verbose:
logging.debug(f'Splitting on depth={depth} stump_num={i} {stump.idxs.sum()}')
idxs_pred = stump.predict(X_text=X_text) > 0.5
for idxs_p, attr in zip([~idxs_pred, idxs_pred], ['child_left', 'child_right']):
# for idxs_p, attr in zip([idxs_pred], ['child_right']):
idxs_child = stump.idxs & idxs_p
if self.verbose:
                        logging.debug(f'\t{idxs_pred.sum()} {idxs_child.sum()} {len(np.unique(y[idxs_child]))}')
if idxs_child.sum() > 0 \
and idxs_child.sum() < stump.idxs.sum() \
and len(np.unique(y[idxs_child])) > 1:
# fit a potential child stump
stump_child = stump_class(**stump_kwargs).fit(
X[idxs_child], y[idxs_child],
X_text=X_text[idxs_child],
feature_names=self.feature_names,
)
# make sure the stump actually found a non-trivial split
if not stump_child.failed_to_split:
# set the child stump
stump_child.idxs = idxs_child
acc_tree_baseline = np.mean(self.predict(
X_text[idxs_child]) == y[idxs_child])
if attr == 'child_left':
stump.child_left = stump_child
else:
stump.child_right = stump_child
stumps_queue_new.append(stump_child)
if self.verbose:
logging.debug(f'\t\t {stump.stump_keywords} {stump.pos_or_neg}')
i += 1
######################### checks ###########################
if self.assert_checks and isinstance(self, ClassifierMixin):
# check acc for the points in this stump
acc_tree = np.mean(self.predict(
X_text[idxs_child]) == y[idxs_child])
                                assert acc_tree >= acc_tree_baseline, f'stump acc after adding child {acc_tree:0.3f} should be >= acc before {acc_tree_baseline:0.3f}'
# check total acc
acc_total_baseline = max(y.mean(), 1 - y.mean())
acc_total = np.mean(self.predict(X_text) == y)
                                assert acc_total >= acc_total_baseline, f'total acc {acc_total:0.3f} should be >= baseline {acc_total_baseline:0.3f}'
                                # check that the stump's train acc improved over this subset
                                # not necessarily going to improve total acc, since the stump always predicts 0/1
                                # even though the correct answer might be always 0 or always 1
acc_child_baseline = min(
y[idxs_child].mean(), 1 - y[idxs_child].mean())
assert stump_child.acc > acc_child_baseline, f'acc {stump_child.acc:0.3f} should be > baseline {acc_child_baseline:0.3f}'
stumps_queue = stumps_queue_new
depth += 1
return self
def predict_proba(self, X_text: List[str] = None):
preds = []
for x_t in X_text:
# prediction for single point
stump = self.root_
while stump:
# 0 or 1 class prediction here
pred = stump.predict(X_text=[x_t])[0]
value = stump.value
if pred > 0.5:
stump = stump.child_right
value = value[1]
else:
stump = stump.child_left
value = value[0]
if stump is None:
preds.append(value)
preds = np.array(preds)
probs = np.vstack((1 - preds, preds)).transpose() # probs (n, 2)
return probs
def predict(self, X_text: List[str] = None) -> np.ndarray[int]:
preds_continuous = self.predict_proba(X_text)[:, 1]
if isinstance(self, ClassifierMixin):
return (preds_continuous > 0.5).astype(int)
else:
return preds_continuous
def get_tree_dict_repr(self) -> Dict[str, List[str]]:
"""
Returns a dictionary representation of the tree
Each key is a binary prefix string
"0" for root
"00" for left child of root
"01" for right child of root
"000" for left child of left child of root, etc.
Each value is a list of strings, where each string is a keyword
"""
tree_dict = {}
stumps_queue = [(self.root_, "0")]
while stumps_queue:
stump, stump_id = stumps_queue.pop(0)
# skip leaf nodes
if stump.child_left is None and stump.child_right is None:
continue
if hasattr(stump, 'stump_keywords_refined'):
keywords = stump.stump_keywords_refined
else:
keywords = stump.stump_keywords
tree_dict[stump_id] = keywords
if stump.child_left:
stumps_queue.append((stump.child_left, stump_id + "0"))
if stump.child_right:
stumps_queue.append((stump.child_right, stump_id + "1"))
return tree_dict
def __str__(self):
s = f'> Tree(max_depth={self.max_depth} max_features={self.max_features} refine={self.refinement_strategy})\n> ------------------------------------------------------\n'
return s + self.viz_tree()
def viz_tree(self, stump: Stump=None, depth: int=0, s: str='') -> str:
if stump is None:
stump = self.root_
s += ' ' * depth + str(stump) + '\n'
if stump.child_left:
s += self.viz_tree(stump.child_left, depth + 1)
else:
s += ' ' * (depth + 1) + f'Neg n={stump.n_samples[0]} val={stump.value[0]:0.3f}' + '\n'
if stump.child_right:
s += self.viz_tree(stump.child_right, depth + 1)
else:
s += ' ' * (depth + 1) + f'Pos n={stump.n_samples[1]} val={stump.value[1]:0.3f}' + '\n'
return s
class AugTreeRegressor(AugTree, RegressorMixin):
...
class AugTreeClassifier(AugTree, ClassifierMixin):
...
Classes
class AugTree (max_depth: int = 3, max_features=5, split_strategy='cart', refinement_strategy='None', verbose=True, tokenizer=None, use_refine_ties=False, assert_checks=False, llm_prompt_context: str = '', use_stemming=False, embs_manager: EmbsManager = None, cache_expansions_dir: str = None)
-
Params
max_depth: int
    Maximum depth of the tree.
max_features: int
    Number of features to consider expanding at each stump
split_strategy: str
    Strategy for generating candidate seed keyphrases.
refinement_strategy: str
    'None', 'llm', or 'embs'
verbose: bool
    Whether to print debug statements
tokenizer
    Tokenizer to use for splitting text into tokens
use_refine_ties: bool
    Whether to keep expanded keywords that neither improve nor hurt performance (ties)
assert_checks: bool
    Whether to run checks during fitting
llm_prompt_context: str
    Extra context string provided to llm_refine (if refinement_strategy='llm')
embs_manager
    Class providing a function to query for keywords with the closest embeddings
cache_expansions_dir: str
    Directory to cache keyphrase expansions
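A minimal usage sketch (the texts and labels below are hypothetical; with X=None, fit() falls back to building a unigram bag-of-words matrix from X_text):

from imodelsx.augtree.augtree import AugTreeClassifier

# hypothetical toy data, for illustration only
X_text = [
    'a wonderful, heartfelt film',
    'dull plot and wooden acting',
    'great performances all around',
    'a tedious, boring mess',
]
y = [1, 0, 1, 0]

model = AugTreeClassifier(max_depth=2, max_features=3, verbose=False)
model.fit(X_text=X_text, y=y)  # X=None triggers the unigram fallback
print(model)                   # renders the fitted tree via viz_tree()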
Subclasses
- AugTreeClassifier
- AugTreeRegressor
Methods
def fit(self, X=None, y=None, feature_names=None, X_text=None)
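fit accepts either raw texts alone (X_text, with X=None triggering the unigram fallback) or a precomputed count matrix X aligned with feature_names. A sketch of the latter, using scikit-learn's CountVectorizer on hypothetical data:

import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from imodelsx.augtree.augtree import AugTreeClassifier

# hypothetical toy data, for illustration only
X_text = ['good fun', 'bad and slow', 'good acting', 'a slow story']
y = np.array([1, 0, 1, 0])

# build the unigram count matrix ourselves instead of relying on the fallback
vectorizer = CountVectorizer(ngram_range=(1, 1))
X = vectorizer.fit_transform(X_text).toarray()
feature_names = list(vectorizer.get_feature_names_out())

model = AugTreeClassifier(max_depth=2, verbose=False)
model.fit(X=X, y=y, feature_names=feature_names, X_text=X_text)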
def get_tree_dict_repr(self) -> Dict[str, List[str]]
-
Returns a dictionary representation of the tree.
Each key is a binary prefix string:
"0" for root, "00" for left child of root, "01" for right child of root,
"000" for left child of left child of root, etc.
Each value is a list of strings, where each string is a keyword.
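Continuing the sketch above, each key is an internal node's binary prefix and each value is the keyword list that node splits on (the exact keywords depend on the fitted data):

tree_dict = model.get_tree_dict_repr()
for node_id, keywords in tree_dict.items():
    print(node_id, keywords)  # e.g. '0' for the root, '01' for its right child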
def predict(self, X_text: List[str] = None) -> numpy.ndarray[int]
def predict_proba(self, X_text: List[str] = None)
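predict_proba routes each text down the tree and stacks the reached leaf values into an (n, 2) array of [P(class 0), P(class 1)]; for classifiers, predict thresholds the second column at 0.5. Continuing the sketch:

probs = model.predict_proba(X_text=['good fun', 'a slow story'])
print(probs.shape)  # (2, 2)

preds = model.predict(X_text=['good fun', 'a slow story'])
print(preds)        # 0/1 labels for AugTreeClassifier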
def viz_tree(self, stump: Stump = None, depth: int = 0, s: str = '') -> str
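viz_tree renders the fitted tree recursively, one stump per line, printing Neg/Pos leaf summaries wherever a child is missing; str(model) prepends a parameter header. Continuing the sketch:

print(model.viz_tree())  # indented, one stump per line
print(model)             # same rendering, preceded by a header line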
class AugTreeClassifier (max_depth: int = 3, max_features=5, split_strategy='cart', refinement_strategy='None', verbose=True, tokenizer=None, use_refine_ties=False, assert_checks=False, llm_prompt_context: str = '', use_stemming=False, embs_manager: EmbsManager = None, cache_expansions_dir: str = None)
-
Mixin class for all classifiers in scikit-learn.
This mixin defines the following functionality:
- an _estimator_type class attribute defaulting to "classifier";
- a score method that defaults to sklearn.metrics.accuracy_score;
- enforcing that fit requires y to be passed through the requires_y tag.
Read more in the scikit-learn User Guide (rolling your own estimator).
Examples
>>> import numpy as np
>>> from sklearn.base import BaseEstimator, ClassifierMixin
>>> # Mixin classes should always be on the left-hand side for a correct MRO
>>> class MyEstimator(ClassifierMixin, BaseEstimator):
...     def __init__(self, *, param=1):
...         self.param = param
...     def fit(self, X, y=None):
...         self.is_fitted_ = True
...         return self
...     def predict(self, X):
...         return np.full(shape=X.shape[0], fill_value=self.param)
>>> estimator = MyEstimator(param=1)
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([1, 1, 1])
>>> estimator.score(X, y)
0.66...
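Because AugTreeClassifier mixes ClassifierMixin into AugTree, it inherits score (mean accuracy); since predict takes the raw texts as its first argument, the texts are passed where score expects X. A sketch on the hypothetical data from the earlier examples:

model = AugTreeClassifier(max_depth=2, verbose=False)
model.fit(X_text=X_text, y=y)  # X_text, y as in the earlier sketches
print(model.score(X_text, y))  # mean accuracy; texts stand in for X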
Ancestors
- AugTree
- sklearn.base.ClassifierMixin
Inherited members
- AugTree: fit, get_tree_dict_repr, predict, predict_proba, viz_tree
class AugTreeRegressor (max_depth: int = 3, max_features=5, split_strategy='cart', refinement_strategy='None', verbose=True, tokenizer=None, use_refine_ties=False, assert_checks=False, llm_prompt_context: str = '', use_stemming=False, embs_manager: EmbsManager = None, cache_expansions_dir: str = None)
-
Mixin class for all regression estimators in scikit-learn.
This mixin defines the following functionality:
- an _estimator_type class attribute defaulting to "regressor";
- a score method that defaults to sklearn.metrics.r2_score;
- enforcing that fit requires y to be passed through the requires_y tag.
Read more in the scikit-learn User Guide (rolling your own estimator).
Examples
>>> import numpy as np
>>> from sklearn.base import BaseEstimator, RegressorMixin
>>> # Mixin classes should always be on the left-hand side for a correct MRO
>>> class MyEstimator(RegressorMixin, BaseEstimator):
...     def __init__(self, *, param=1):
...         self.param = param
...     def fit(self, X, y=None):
...         self.is_fitted_ = True
...         return self
...     def predict(self, X):
...         return np.full(shape=X.shape[0], fill_value=self.param)
>>> estimator = MyEstimator(param=0)
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([-1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([0, 0, 0])
>>> estimator.score(X, y)
0.0
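AugTreeRegressor runs the same fitting loop with StumpRegressor, and its predict returns the continuous column of predict_proba rather than thresholded labels. A sketch with hypothetical continuous targets:

import numpy as np
from imodelsx.augtree.augtree import AugTreeRegressor

# hypothetical toy data, for illustration only
X_text = ['great fun', 'a boring mess', 'pretty good', 'quite bad']
y = np.array([0.9, 0.1, 0.7, 0.2])

reg = AugTreeRegressor(max_depth=2, verbose=False)
reg.fit(X_text=X_text, y=y)
print(reg.predict(X_text=['great fun', 'quite bad']))  # continuous predictions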
Ancestors
- AugTree
- sklearn.base.RegressorMixin
Inherited members
- AugTree: fit, get_tree_dict_repr, predict, predict_proba, viz_tree