Expand source code
from collections import Counter
from typing import List
from imodels.util.rule import Rule
def extract_ensemble(weak_learners, X, y, min_multiplicity: int = 1) -> List[str]:
    """Fit every weak learner and return the rules that recur across them.

    Each estimator in ``weak_learners`` is fitted on ``(X, y)`` and its
    ``rules_`` are pooled. Complete rules and their individual sub-terms (as
    produced by ``split``) are counted separately, and any whose count is
    strictly greater than ``min_multiplicity`` is kept.

    Parameters
    ----------
    weak_learners : iterable of estimators
        Each must provide ``fit(X, y)`` and expose ``rules_`` afterwards.
        The estimators are fitted in place.
    X, y
        Training data, forwarded unchanged to every ``fit`` call.
    min_multiplicity : int, default=1
        Count threshold (exclusive). When it is positive, every value in each
        pooled rule's ``agg_dict`` is rounded in place to one decimal before
        counting, so that near-identical rules can match.

    Returns
    -------
    List[str]
        ``str()`` of each retained rule or sub-term. The list is built from a
        set, so its order is not meaningful.
    """
    all_rules = []
    all_subterms = []
    for est in weak_learners:
        est.fit(X, y)
        all_rules += est.rules_
        # Collected through a set so that a sub-term is counted at most once
        # per estimator, however many of that estimator's rules contain it.
        all_subterms += {term for rule in est.rules_ for term in split(rule)}

    if min_multiplicity > 0:
        # Round rule decision boundaries to increase matching.
        # NOTE(review): only objects in ``all_rules`` are rounded. Sub-terms
        # that ``split`` built as new Rule objects (multi-term rules) were
        # created before this point and are left untouched - confirm whether
        # that is intended.
        for rule in all_rules:
            boundaries = rule.agg_dict
            for key in boundaries:
                boundaries[key] = round(float(boundaries[key]), 1)

    # Full rules seen more than ``min_multiplicity`` times.
    repeated_rules = {
        rule for rule, count in Counter(all_rules).items()
        if count > min_multiplicity
    }
    # Sub-terms seen (in distinct estimators) more than ``min_multiplicity`` times.
    repeated_rules |= {
        term for term, count in Counter(all_subterms).items()
        if count > min_multiplicity
    }

    # Convert to str form to be rescored.
    return [str(rule) for rule in repeated_rules]
def split(rule: Rule) -> List[Rule]:
    """Break a rule into a list holding one ``Rule`` per term.

    A rule whose ``agg_dict`` has exactly one entry is returned as-is (the
    same object) inside a one-element list. Otherwise each entry of
    ``rule.terms`` is joined with single spaces and handed to ``Rule`` to
    build a new object.
    """
    if len(rule.agg_dict) == 1:
        return [rule]

    term_texts = [' '.join(pieces) for pieces in rule.terms]
    return [Rule(text) for text in term_texts]
Functions
def extract_ensemble(weak_learners, X, y, min_multiplicity: int = 1) ‑> List[Rule]
-
Expand source code
def extract_ensemble(weak_learners, X, y, min_multiplicity: int = 1) -> List[str]:
    """Fit every weak learner and return the rules that recur across them.

    Each estimator in ``weak_learners`` is fitted on ``(X, y)`` and its
    ``rules_`` are pooled. Complete rules and their individual sub-terms (as
    produced by ``split``) are counted separately, and any whose count is
    strictly greater than ``min_multiplicity`` is kept.

    Parameters
    ----------
    weak_learners : iterable of estimators
        Each must provide ``fit(X, y)`` and expose ``rules_`` afterwards.
        The estimators are fitted in place.
    X, y
        Training data, forwarded unchanged to every ``fit`` call.
    min_multiplicity : int, default=1
        Count threshold (exclusive). When it is positive, every value in each
        pooled rule's ``agg_dict`` is rounded in place to one decimal before
        counting, so that near-identical rules can match.

    Returns
    -------
    List[str]
        ``str()`` of each retained rule or sub-term. The list is built from a
        set, so its order is not meaningful.
    """
    all_rules = []
    all_subterms = []
    for est in weak_learners:
        est.fit(X, y)
        all_rules += est.rules_
        # Collected through a set so that a sub-term is counted at most once
        # per estimator, however many of that estimator's rules contain it.
        all_subterms += {term for rule in est.rules_ for term in split(rule)}

    if min_multiplicity > 0:
        # Round rule decision boundaries to increase matching.
        # NOTE(review): only objects in ``all_rules`` are rounded. Sub-terms
        # that ``split`` built as new Rule objects (multi-term rules) were
        # created before this point and are left untouched - confirm whether
        # that is intended.
        for rule in all_rules:
            boundaries = rule.agg_dict
            for key in boundaries:
                boundaries[key] = round(float(boundaries[key]), 1)

    # Full rules seen more than ``min_multiplicity`` times.
    repeated_rules = {
        rule for rule, count in Counter(all_rules).items()
        if count > min_multiplicity
    }
    # Sub-terms seen (in distinct estimators) more than ``min_multiplicity`` times.
    repeated_rules |= {
        term for term, count in Counter(all_subterms).items()
        if count > min_multiplicity
    }

    # Convert to str form to be rescored.
    return [str(rule) for rule in repeated_rules]
def split(rule: Rule) ‑> List[Rule]
-
Expand source code
def split(rule: Rule) -> List[Rule]:
    """Break a rule into a list holding one ``Rule`` per term.

    A rule whose ``agg_dict`` has exactly one entry is returned as-is (the
    same object) inside a one-element list. Otherwise each entry of
    ``rule.terms`` is joined with single spaces and handed to ``Rule`` to
    build a new object.
    """
    if len(rule.agg_dict) == 1:
        return [rule]

    term_texts = [' '.join(pieces) for pieces in rule.terms]
    return [Rule(text) for text in term_texts]