Scikit-learn friendly library to explain, predict, and steer text models/data.
Also a bunch of utilities for getting started with text data.

📖 demo notebooks

Explainable modeling/steering

| Model | Reference | Output | Description |
| :-- | :-- | :-- | :-- |
| Tree-Prompt | 📖, 🗂️, 🔗, 📄 | Explanation + Steering | Generates a tree of prompts to steer an LLM (Official) |
| iPrompt | 📖, 🗂️, 🔗, 📄 | Explanation + Steering | Generates a prompt that explains patterns in data (Official) |
| D3 | 📖, 🗂️, 🔗, 📄 | Explanation | Explain the difference between two distributions |
| SASC | 🗂️, 🔗, 📄 | Explanation | Explain a black-box text module using an LLM (Official) |
| AutoPrompt | 🗂️, 🔗, 📄 | Explanation | Find a natural-language prompt using input-gradients (⌛ In progress) |
| Aug-GAM | 📖, 🗂️, 🔗, 📄 | Linear model | Fit better linear model using an LLM to extract embeddings (Official) |
| Aug-Tree | 📖, 🗂️, 🔗, 📄 | Decision tree | Fit better decision tree using an LLM to expand features (Official) |

📖Demo notebooks   🗂️ Doc   🔗 Reference code   📄 Research paper
⌛ We plan to support other interpretable algorithms like RLPrompt, CBMs, and NBDT. If you want to contribute an algorithm, feel free to open a PR 😄

General utilities

| Utility | Reference | Description |
| :-- | :-- | :-- |
| LLM wrapper | 🗂️ | Easily call different LLMs |
| Dataset wrapper | 🗂️ | Download minimally processed huggingface datasets |
| Bag of Ngrams | 🗂️ | Learn a linear model of ngrams |
| Linear Finetune | 🗂️ | Finetune a single linear layer on top of LLM embeddings |

Quickstart

Installation: pip install imodelsx (or, for more control, clone and install from source)

Demos: see the demo notebooks

Explainable models

Natural-language explanations

iPrompt

from imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset

# get a simple dataset of adding two numbers
input_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)
for i in range(5):
    print(repr(input_strings[i]), repr(output_strings[i]))

# explain the relationship between the inputs and outputs
# with a natural-language prompt string
prompts, metadata = explain_dataset_iprompt(
    input_strings=input_strings,
    output_strings=output_strings,
    checkpoint='EleutherAI/gpt-j-6B', # which language model to use
    num_learned_tokens=3, # how long of a prompt to learn
    n_shots=3, # shots per example
    n_epochs=15, # how many epochs to search
    verbose=0, # how much to print
    llm_float16=True, # whether to load the model in float16
)
# prompts is a list of the natural-language prompt strings that were found

D3 (DescribeDistributionalDifferences)

from imodelsx import explain_dataset_d3
hypotheses, hypothesis_scores = explain_dataset_d3(
    pos=positive_samples, # List[str] of positive examples
    neg=negative_samples, # another List[str]
    num_steps=100,
    num_folds=2,
    batch_size=64,
)

SASC

Here, we explain a module rather than a dataset.

import numpy as np
from imodelsx import explain_module_sasc

# a toy module that responds to the length of a string
mod = lambda str_list: np.array([len(s) for s in str_list])

# a toy dataset where the longest strings are animals
text_str_list = ["red", "blue", "x", "1", "2", "hippopotamus", "elephant", "rhinoceros"]
explanation_dict = explain_module_sasc(
    text_str_list,
    mod,
    ngrams=1,
)
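
To make the toy setup concrete, the sketch below ranks the strings by the module's response, which is the kind of evidence an explanation would be summarized from. It is a plain-Python illustration, not the `explain_module_sasc` implementation; `toy_module` and `top_responding` are hypothetical names introduced here.

```python
from typing import Callable, List, Tuple


def toy_module(str_list: List[str]) -> List[int]:
    """Toy module from the example above: responds to string length."""
    return [len(s) for s in str_list]


def top_responding(
    texts: List[str],
    module: Callable[[List[str]], List[int]],
    k: int = 3,
) -> List[Tuple[str, int]]:
    """Return the k texts with the largest module response (hypothetical helper)."""
    scores = module(texts)
    ranked = sorted(zip(texts, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


text_str_list = ["red", "blue", "x", "1", "2", "hippopotamus", "elephant", "rhinoceros"]
top = top_responding(text_str_list, toy_module, k=3)
print(top)  # the three longest strings, all animals
```

In this toy case the top-responding strings are all animal names, so a reasonable explanation of the module would mention animals (or long words).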

Aug-imodels

Use these just like a scikit-learn model. During training, they fit better features via LLMs, but at test-time they are extremely fast and completely transparent.

from imodelsx import AugGAMClassifier, AugTreeClassifier, AugGAMRegressor, AugTreeRegressor
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2, # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
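
One way to see why test-time prediction can be fast and transparent: a dictionary of ngram coefficients, like `coefs_dict_` above, can be applied by simple lookup and summation. The sketch below shows that idea with made-up coefficients. `toy_coefs`, `extract_ngrams`, and `additive_score` are illustrative names, and the real `AugGAMClassifier` may tokenize text and combine terms differently.

```python
from typing import Dict, List


def extract_ngrams(text: str, max_n: int = 2) -> List[str]:
    """Split on whitespace and return all ngrams of length 1..max_n."""
    tokens = text.lower().split()
    ngrams = []
    for n in range(1, max_n + 1):
        for i in range(len(tokens) - n + 1):
            ngrams.append(" ".join(tokens[i:i + n]))
    return ngrams


def additive_score(text: str, coefs: Dict[str, float], max_n: int = 2) -> float:
    """Sum the coefficients of the ngrams in text; unseen ngrams contribute 0."""
    return sum(coefs.get(ng, 0.0) for ng in extract_ngrams(text, max_n))


# made-up coefficients, for illustration only
toy_coefs = {"great": 1.5, "boring": -2.0, "not great": -2.5}

score_pos = additive_score("a great movie", toy_coefs)
score_neg = additive_score("not great and boring", toy_coefs)
print(score_pos, score_neg)
```

Because each ngram contributes a fixed amount, the score for any input can be audited term by term.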

General utilities

Easy baselines

Easy-to-fit baselines that follow the sklearn API.

from imodelsx import LinearFinetuneClassifier, LinearNgramClassifier
# fit a simple one-layer finetune on top of LLM embeddings
m = LinearFinetuneClassifier(
    checkpoint='distilbert-base-uncased',
)
m.fit(dset['text'], dset['label'])
preds = m.predict(dset_val['text'])
acc = (preds == dset_val['label']).mean()
print('validation acc', acc)

LLM wrapper

Easy API for calling different language models with caching (much more lightweight than langchain).

import imodelsx.llm
# supports any huggingface checkpoint or openai checkpoint (including chat models)
llm = imodelsx.llm.get_llm(
    checkpoint="gpt2-xl",  # text-davinci-003, gpt-3.5-turbo, ...
    CACHE_DIR=".cache",
)
out = llm("May the Force be")
llm("May the Force be") # when computing the same string again, uses the cache
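
The caching behaviour can be pictured as a wrapper that stores each response on disk under a hash of the prompt. This is a minimal sketch of that idea, not the imodelsx implementation; `CachedLLM`, `fake_llm`, and the on-disk layout are assumptions made for illustration.

```python
import hashlib
import json
import os
import tempfile
from typing import Callable, List


class CachedLLM:
    """Wrap a prompt -> text function with a simple on-disk cache (hypothetical)."""

    def __init__(self, generate: Callable[[str], str], cache_dir: str):
        self.generate = generate
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)

    def __call__(self, prompt: str) -> str:
        # one JSON file per prompt, named by the prompt's SHA-256 hash
        key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
        path = os.path.join(self.cache_dir, f"{key}.json")
        if os.path.exists(path):
            with open(path, encoding="utf-8") as f:
                return json.load(f)["output"]
        output = self.generate(prompt)
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"prompt": prompt, "output": output}, f)
        return output


calls: List[str] = []


def fake_llm(prompt: str) -> str:
    """Stand-in for a real model call; records each invocation."""
    calls.append(prompt)
    return prompt + " with you"


with tempfile.TemporaryDirectory() as tmp_dir:
    cached_llm = CachedLLM(fake_llm, cache_dir=tmp_dir)
    first = cached_llm("May the Force be")
    second = cached_llm("May the Force be")  # served from the cache

print(first, len(calls))
```

Keying the cache on the prompt alone is the simplest choice; a fuller version would presumably also include the checkpoint name and generation settings in the key.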

Data wrapper

API for loading huggingface datasets with basic preprocessing.

import imodelsx.data
dset, dataset_key_text = imodelsx.data.load_huggingface_dataset('ag_news')
# Ensures that dset has splits named 'train' and 'validation',
# and that the input text for each split is in the column named by dataset_key_text
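
To picture the split guarantee: if a dataset ships without a 'validation' split, some other held-out split has to stand in for it. The sketch below does this on a plain dict of splits. `ensure_train_validation` and its fallback to 'test' are assumptions for illustration, not necessarily what `load_huggingface_dataset` does.

```python
from typing import Dict, List


def ensure_train_validation(splits: Dict[str, List[dict]]) -> Dict[str, List[dict]]:
    """Return a copy of splits that has both 'train' and 'validation' keys.

    Hypothetical helper: falls back to renaming 'test' when 'validation' is missing.
    """
    if "train" not in splits:
        raise KeyError("expected a 'train' split")
    out = dict(splits)  # shallow copy so the caller's dict is left untouched
    if "validation" not in out:
        if "test" not in out:
            raise KeyError("expected a 'validation' or 'test' split")
        out["validation"] = out.pop("test")
    return out


toy_splits = {
    "train": [{"text": "stocks rally", "label": 2}],
    "test": [{"text": "team wins final", "label": 1}],
}
fixed = ensure_train_validation(toy_splits)
print(sorted(fixed))
```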

Related work

  • imodels package (JOSS 2021 github) - interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).
  • Adaptive wavelet distillation (NeurIPS 2021 pdf, github) - distilling a neural network into a concise wavelet model
  • Transformation importance (ICLR 2020 workshop pdf, github) - using simple reparameterizations, allows for calculating disentangled importances to transformations of the input (e.g. assigning importances to different frequencies)
  • Hierarchical interpretations (ICLR 2019 pdf, github) - extends CD to CNNs / arbitrary DNNs, and aggregates explanations into a hierarchy
  • Interpretation regularization (ICML 2020 pdf, github) - penalizes CD / ACD scores during training to make models generalize better
  • PDR interpretability framework (PNAS 2019 pdf) - an overarching framework for guiding and framing interpretable machine learning
"""
.. include:: ../readme.md
"""

from .auggam.auggam import AugGAMClassifier, AugGAMRegressor
from .augtree.augtree import AugTreeClassifier, AugTreeRegressor
from .linear_finetune import LinearFinetuneClassifier, LinearFinetuneRegressor
from .linear_ngram import LinearNgramClassifier, LinearNgramRegressor
from .d3.d3 import explain_dataset_d3
from .iprompt.api import explain_dataset_iprompt
from .iprompt.data import get_add_two_numbers_dataset
from .sasc.api import explain_module_sasc
from .treeprompt.treeprompt import TreePromptClassifier

Sub-modules

imodelsx.auggam
imodelsx.augtree
imodelsx.cache_save_utils
imodelsx.d3
imodelsx.data
imodelsx.dummy_script
imodelsx.embeddings
imodelsx.iprompt
imodelsx.linear_finetune

Simple scikit-learn interface for finetuning a single linear layer on top of LLM embeddings.

imodelsx.linear_ngram

Simple scikit-learn interface for learning a linear model of ngrams.

imodelsx.llm
imodelsx.metrics
imodelsx.process_results
imodelsx.sasc
imodelsx.submit_utils
imodelsx.treeprompt
imodelsx.util
imodelsx.viz