
Expand source code
import torch
import torch.nn.functional as F
import numpy as np
from scipy.special import expit as sigmoid
from .cd_propagate import *
from .cd_architecture_specific import *

def cd(im_torch: torch.Tensor, model, mask=None, model_type=None, device='cuda', transform=None):
    '''Get contextual decomposition scores for some set of inputs for a specific image

    Parameters
    ----------
    im_torch: torch.Tensor
        example to interpret - usually has shape (batch_size, num_channels, height, width)
    model: pytorch model
    mask: array_like (values in {0, 1})
        required unless transform is supplied
        array with 1s marking the locations of relevant pixels, 0s marking the background
        shape should match the shape of im_torch or just H x W
    model_type: str, optional
        usually should just leave this blank
        if this is == 'mnist', uses CD for a specific mnist model
        if this is == 'resnet18', uses resnet18 model
    device: str, optional
        device the model, the input and the decomposed tensors are moved to
    transform: function, optional
        transform should be a function which transforms the original image to specify rel
        only used if mask is not passed

    Returns
    -------
    relevant: torch.Tensor
        class-wise scores for relevant mask
    irrelevant: torch.Tensor
        class-wise scores for everything but the relevant mask

    Raises
    ------
    ValueError
        if neither mask nor transform is supplied
    '''
    # set up model
    # NOTE(review): the right-hand sides of these two assignments were missing from
    # the listing this was restored from; moving to `device` is assumed - confirm.
    model = model.to(device)
    im_torch = im_torch.to(device)

    # set up relevant/irrelevant based on mask
    if mask is not None:
        mask = torch.FloatTensor(mask).to(device)
        relevant = mask * im_torch
        irrelevant = (1 - mask) * im_torch
    elif transform is not None:
        relevant = transform(im_torch).to(device)
        if len(relevant.shape) < 4:
            # a 2-d result is lifted to 4 dims by adding two leading singleton axes
            relevant = relevant.reshape(1, 1, relevant.shape[0], relevant.shape[1])
        irrelevant = im_torch - relevant
    else:
        # fail early: without either argument there is nothing to decompose
        raise ValueError('mask or transform arguments required!')
    # NOTE(review): right-hand sides restored as device moves - confirm.
    relevant = relevant.to(device)
    irrelevant = irrelevant.to(device)

    # deal with specific architectures which cannot be handled generically
    if model_type == 'mnist':
        return cd_propagate_mnist(relevant, irrelevant, model)
    elif model_type == 'resnet18':
        return cd_propagate_resnet(relevant, irrelevant, model)

    # try the generic case
    mods = list(model.modules())
    relevant, irrelevant = cd_generic(mods, relevant, irrelevant)
    return relevant, irrelevant

def cd_generic(mods, relevant, irrelevant):
    '''Helper function for cd which loops over modules and propagates them
    based on the layer name

    Parameters
    ----------
    mods: iterable of modules
        modules to propagate through, visited in order; each is dispatched on
        the string form of its type
    relevant
        relevant part of the decomposition entering the first module
    irrelevant
        irrelevant part of the decomposition entering the first module

    Returns
    -------
    relevant, irrelevant
        the two parts after passing through every recognised module; modules
        whose type name matches none of the branches leave both unchanged
    '''
    for mod in mods:
        t = str(type(mod))
        if 'Conv2d' in t:
            relevant, irrelevant = propagate_conv_linear(relevant, irrelevant, mod)
        elif 'Linear' in t:
            # flatten everything after the leading axis before the linear layer
            relevant = relevant.reshape(relevant.shape[0], -1)
            irrelevant = irrelevant.reshape(irrelevant.shape[0], -1)
            relevant, irrelevant = propagate_conv_linear(relevant, irrelevant, mod)
        elif 'ReLU' in t:
            relevant, irrelevant = propagate_relu(relevant, irrelevant, mod)
        elif 'AvgPool' in t or 'NormLayer' in t or 'Dropout' in t \
                or 'ReshapeLayer' in t or ('modularize' in t and 'Transform' in t):  # custom layers
            relevant, irrelevant = propagate_independent(relevant, irrelevant, mod)
        elif 'Pool' in t:
            # AvgPool was already handled by the branch above, so this is any other pooling
            relevant, irrelevant = propagate_pooling(relevant, irrelevant, mod)
        elif 'BatchNorm2d' in t:
            relevant, irrelevant = propagate_batchnorm2d(relevant, irrelevant, mod)
    return relevant, irrelevant

def cd_text(batch, model, start, stop, return_irrel_scores=False):
    '''Get contextual decomposition scores for substring of a text sequence

    Parameters
    ----------
    batch: torchtext batch
        really only requires that batch.text is the string input to be interpreted
    model
        model to interpret; this function reads model.lstm, model.embed,
        model.hidden_dim and model.hidden_to_label
    start: int
        beginning index of substring to be interpreted (inclusive)
    stop: int
        ending index of substring to be interpreted (inclusive)
    return_irrel_scores: bool, optional
        if True, also return the scores of everything outside the substring

    Returns
    -------
    scores: np.ndarray
        class-wise scores for relevant substring
    irrel_scores: np.ndarray
        only returned when return_irrel_scores is True
    '''
    weights = model.lstm.state_dict()

    # Index one = word vector (i) or hidden state (h), index two = gate
    # Everything is converted to numpy up front (as the biases already were) so
    # that np.split / np.dot never receive torch tensors.
    W_ii, W_if, W_ig, W_io = np.split(weights['weight_ih_l0'].cpu().numpy(), 4, 0)
    W_hi, W_hf, W_hg, W_ho = np.split(weights['weight_hh_l0'].cpu().numpy(), 4, 0)
    b_i, b_f, b_g, b_o = np.split(weights['bias_ih_l0'].cpu().numpy() + weights['bias_hh_l0'].cpu().numpy(), 4)
    word_vecs = model.embed(batch.text)[:, 0].data.cpu().numpy()
    T = word_vecs.shape[0]
    relevant = np.zeros((T, model.hidden_dim))
    irrelevant = np.zeros((T, model.hidden_dim))
    relevant_h = np.zeros((T, model.hidden_dim))
    irrelevant_h = np.zeros((T, model.hidden_dim))
    for i in range(T):
        if i > 0:
            prev_rel_h = relevant_h[i - 1]
            prev_irrel_h = irrelevant_h[i - 1]
        else:
            # no previous step: both parts of the hidden state start at zero
            prev_rel_h = np.zeros(model.hidden_dim)
            prev_irrel_h = np.zeros(model.hidden_dim)

        # recurrent contribution of the previous hidden state to each gate
        rel_i = np.dot(W_hi, prev_rel_h)
        rel_g = np.dot(W_hg, prev_rel_h)
        rel_f = np.dot(W_hf, prev_rel_h)
        rel_o = np.dot(W_ho, prev_rel_h)
        irrel_i = np.dot(W_hi, prev_irrel_h)
        irrel_g = np.dot(W_hg, prev_irrel_h)
        irrel_f = np.dot(W_hf, prev_irrel_h)
        irrel_o = np.dot(W_ho, prev_irrel_h)

        # the current word feeds the relevant part only inside [start, stop]
        if start <= i <= stop:
            rel_i = rel_i + np.dot(W_ii, word_vecs[i])
            rel_g = rel_g + np.dot(W_ig, word_vecs[i])
            rel_f = rel_f + np.dot(W_if, word_vecs[i])
            rel_o = rel_o + np.dot(W_io, word_vecs[i])
        else:
            irrel_i = irrel_i + np.dot(W_ii, word_vecs[i])
            irrel_g = irrel_g + np.dot(W_ig, word_vecs[i])
            irrel_f = irrel_f + np.dot(W_if, word_vecs[i])
            irrel_o = irrel_o + np.dot(W_io, word_vecs[i])

        rel_contrib_i, irrel_contrib_i, bias_contrib_i = propagate_three(rel_i, irrel_i, b_i, sigmoid)
        rel_contrib_g, irrel_contrib_g, bias_contrib_g = propagate_three(rel_g, irrel_g, b_g, np.tanh)

        relevant[i] = rel_contrib_i * (rel_contrib_g + bias_contrib_g) + bias_contrib_i * rel_contrib_g
        irrelevant[i] = irrel_contrib_i * (rel_contrib_g + irrel_contrib_g + bias_contrib_g) + (
                    rel_contrib_i + bias_contrib_i) * irrel_contrib_g

        # the bias-bias product goes to whichever part the current word belongs to
        if start <= i <= stop:
            relevant[i] += bias_contrib_i * bias_contrib_g
        else:
            irrelevant[i] += bias_contrib_i * bias_contrib_g

        if i > 0:
            # forget gate carries over the previous step's relevant/irrelevant values
            rel_contrib_f, irrel_contrib_f, bias_contrib_f = propagate_three(rel_f, irrel_f, b_f, sigmoid)
            relevant[i] += (rel_contrib_f + bias_contrib_f) * relevant[i - 1]
            irrelevant[i] += (rel_contrib_f + irrel_contrib_f + bias_contrib_f) * irrelevant[i - 1] + irrel_contrib_f * \
                             relevant[i - 1]

        o = sigmoid(np.dot(W_io, word_vecs[i]) + np.dot(W_ho, prev_rel_h + prev_irrel_h) + b_o)
        # NOTE: the decomposed output-gate contributions are computed but not used
        # below; the full gate activation `o` scales both parts instead. The
        # commented lines show the alternative that would use them.
        rel_contrib_o, irrel_contrib_o, bias_contrib_o = propagate_three(rel_o, irrel_o, b_o, sigmoid)
        new_rel_h, new_irrel_h = propagate_tanh_two(relevant[i], irrelevant[i])
        # relevant_h[i] = new_rel_h * (rel_contrib_o + bias_contrib_o)
        # irrelevant_h[i] = new_rel_h * (irrel_contrib_o) + new_irrel_h * (rel_contrib_o + irrel_contrib_o + bias_contrib_o)
        relevant_h[i] = o * new_rel_h
        irrelevant_h[i] = o * new_irrel_h

    # NOTE(review): the right-hand side was missing from the listing this was
    # restored from; the weight of model.hidden_to_label is assumed - confirm.
    W_out = model.hidden_to_label.weight.data.cpu().numpy()

    # Sanity check: scores + irrel_scores should equal the LSTM's output minus model.hidden_to_label.bias
    scores = np.dot(W_out, relevant_h[T - 1])
    irrel_scores = np.dot(W_out, irrelevant_h[T - 1])

    if return_irrel_scores:
        return scores, irrel_scores

    return scores


def cd(im_torch, model, mask=None, model_type=None, device='cuda', transform=None)

Get contextual decomposition scores for some set of inputs for a specific image


im_torch : torch.Tensor
example to interpret - usually has shape (batch_size, num_channels, height, width)
model : pytorch model
mask : array_like (values in {0, 1})
required unless transform is supplied; array with 1s marking the locations of relevant pixels and 0s marking the background; shape should match the shape of im_torch or be just H x W
model_type : str, optional
usually should just leave this blank; if this is == 'mnist', uses CD for a specific mnist model; if this is == 'resnet18', uses resnet18 model
device : str, optional
transform : function, optional
transform should be a function which transforms the original image into its relevant part (rel); only used if mask is not passed


relevant : torch.Tensor
class-wise scores for relevant mask
irrelevant : torch.Tensor
class-wise scores for everything but the relevant mask
Expand source code
def cd(im_torch: torch.Tensor, model, mask=None, model_type=None, device='cuda', transform=None):
    '''Get contextual decomposition scores for some set of inputs for a specific image

    Parameters
    ----------
    im_torch: torch.Tensor
        example to interpret - usually has shape (batch_size, num_channels, height, width)
    model: pytorch model
    mask: array_like (values in {0, 1})
        required unless transform is supplied
        array with 1s marking the locations of relevant pixels, 0s marking the background
        shape should match the shape of im_torch or just H x W
    model_type: str, optional
        usually should just leave this blank
        if this is == 'mnist', uses CD for a specific mnist model
        if this is == 'resnet18', uses resnet18 model
    device: str, optional
        device the model, the input and the decomposed tensors are moved to
    transform: function, optional
        transform should be a function which transforms the original image to specify rel
        only used if mask is not passed

    Returns
    -------
    relevant: torch.Tensor
        class-wise scores for relevant mask
    irrelevant: torch.Tensor
        class-wise scores for everything but the relevant mask

    Raises
    ------
    ValueError
        if neither mask nor transform is supplied
    '''
    # set up model
    # NOTE(review): the right-hand sides of these two assignments were missing from
    # the listing this was restored from; moving to `device` is assumed - confirm.
    model = model.to(device)
    im_torch = im_torch.to(device)

    # set up relevant/irrelevant based on mask
    if mask is not None:
        mask = torch.FloatTensor(mask).to(device)
        relevant = mask * im_torch
        irrelevant = (1 - mask) * im_torch
    elif transform is not None:
        relevant = transform(im_torch).to(device)
        if len(relevant.shape) < 4:
            # a 2-d result is lifted to 4 dims by adding two leading singleton axes
            relevant = relevant.reshape(1, 1, relevant.shape[0], relevant.shape[1])
        irrelevant = im_torch - relevant
    else:
        # fail early: without either argument there is nothing to decompose
        raise ValueError('mask or transform arguments required!')
    # NOTE(review): right-hand sides restored as device moves - confirm.
    relevant = relevant.to(device)
    irrelevant = irrelevant.to(device)

    # deal with specific architectures which cannot be handled generically
    if model_type == 'mnist':
        return cd_propagate_mnist(relevant, irrelevant, model)
    elif model_type == 'resnet18':
        return cd_propagate_resnet(relevant, irrelevant, model)

    # try the generic case
    mods = list(model.modules())
    relevant, irrelevant = cd_generic(mods, relevant, irrelevant)
    return relevant, irrelevant
def cd_generic(mods, relevant, irrelevant)

Helper function for cd which loops over modules and propagates them based on the layer name

Expand source code
def cd_generic(mods, relevant, irrelevant):
    '''Helper function for cd which loops over modules and propagates them
    based on the layer name

    Parameters
    ----------
    mods: iterable of modules
        modules to propagate through, visited in order; each is dispatched on
        the string form of its type
    relevant
        relevant part of the decomposition entering the first module
    irrelevant
        irrelevant part of the decomposition entering the first module

    Returns
    -------
    relevant, irrelevant
        the two parts after passing through every recognised module; modules
        whose type name matches none of the branches leave both unchanged
    '''
    for mod in mods:
        t = str(type(mod))
        if 'Conv2d' in t:
            relevant, irrelevant = propagate_conv_linear(relevant, irrelevant, mod)
        elif 'Linear' in t:
            # flatten everything after the leading axis before the linear layer
            relevant = relevant.reshape(relevant.shape[0], -1)
            irrelevant = irrelevant.reshape(irrelevant.shape[0], -1)
            relevant, irrelevant = propagate_conv_linear(relevant, irrelevant, mod)
        elif 'ReLU' in t:
            relevant, irrelevant = propagate_relu(relevant, irrelevant, mod)
        elif 'AvgPool' in t or 'NormLayer' in t or 'Dropout' in t \
                or 'ReshapeLayer' in t or ('modularize' in t and 'Transform' in t):  # custom layers
            relevant, irrelevant = propagate_independent(relevant, irrelevant, mod)
        elif 'Pool' in t:
            # AvgPool was already handled by the branch above, so this is any other pooling
            relevant, irrelevant = propagate_pooling(relevant, irrelevant, mod)
        elif 'BatchNorm2d' in t:
            relevant, irrelevant = propagate_batchnorm2d(relevant, irrelevant, mod)
    return relevant, irrelevant
def cd_text(batch, model, start, stop, return_irrel_scores=False)

Get contextual decomposition scores for substring of a text sequence


batch: torchtext batch
    really only requires that batch.text is the string input to be interpreted
start: int
    beginning index of substring to be interpreted (inclusive)
stop: int
    ending index of substring to be interpreted (inclusive)


scores: torch.Tensor
    class-wise scores for relevant substring
Expand source code
def cd_text(batch, model, start, stop, return_irrel_scores=False):
    '''Get contextual decomposition scores for substring of a text sequence

    Parameters
    ----------
    batch: torchtext batch
        really only requires that batch.text is the string input to be interpreted
    model
        model to interpret; this function reads model.lstm, model.embed,
        model.hidden_dim and model.hidden_to_label
    start: int
        beginning index of substring to be interpreted (inclusive)
    stop: int
        ending index of substring to be interpreted (inclusive)
    return_irrel_scores: bool, optional
        if True, also return the scores of everything outside the substring

    Returns
    -------
    scores: np.ndarray
        class-wise scores for relevant substring
    irrel_scores: np.ndarray
        only returned when return_irrel_scores is True
    '''
    weights = model.lstm.state_dict()

    # Index one = word vector (i) or hidden state (h), index two = gate
    # Everything is converted to numpy up front (as the biases already were) so
    # that np.split / np.dot never receive torch tensors.
    W_ii, W_if, W_ig, W_io = np.split(weights['weight_ih_l0'].cpu().numpy(), 4, 0)
    W_hi, W_hf, W_hg, W_ho = np.split(weights['weight_hh_l0'].cpu().numpy(), 4, 0)
    b_i, b_f, b_g, b_o = np.split(weights['bias_ih_l0'].cpu().numpy() + weights['bias_hh_l0'].cpu().numpy(), 4)
    word_vecs = model.embed(batch.text)[:, 0].data.cpu().numpy()
    T = word_vecs.shape[0]
    relevant = np.zeros((T, model.hidden_dim))
    irrelevant = np.zeros((T, model.hidden_dim))
    relevant_h = np.zeros((T, model.hidden_dim))
    irrelevant_h = np.zeros((T, model.hidden_dim))
    for i in range(T):
        if i > 0:
            prev_rel_h = relevant_h[i - 1]
            prev_irrel_h = irrelevant_h[i - 1]
        else:
            # no previous step: both parts of the hidden state start at zero
            prev_rel_h = np.zeros(model.hidden_dim)
            prev_irrel_h = np.zeros(model.hidden_dim)

        # recurrent contribution of the previous hidden state to each gate
        rel_i = np.dot(W_hi, prev_rel_h)
        rel_g = np.dot(W_hg, prev_rel_h)
        rel_f = np.dot(W_hf, prev_rel_h)
        rel_o = np.dot(W_ho, prev_rel_h)
        irrel_i = np.dot(W_hi, prev_irrel_h)
        irrel_g = np.dot(W_hg, prev_irrel_h)
        irrel_f = np.dot(W_hf, prev_irrel_h)
        irrel_o = np.dot(W_ho, prev_irrel_h)

        # the current word feeds the relevant part only inside [start, stop]
        if start <= i <= stop:
            rel_i = rel_i + np.dot(W_ii, word_vecs[i])
            rel_g = rel_g + np.dot(W_ig, word_vecs[i])
            rel_f = rel_f + np.dot(W_if, word_vecs[i])
            rel_o = rel_o + np.dot(W_io, word_vecs[i])
        else:
            irrel_i = irrel_i + np.dot(W_ii, word_vecs[i])
            irrel_g = irrel_g + np.dot(W_ig, word_vecs[i])
            irrel_f = irrel_f + np.dot(W_if, word_vecs[i])
            irrel_o = irrel_o + np.dot(W_io, word_vecs[i])

        rel_contrib_i, irrel_contrib_i, bias_contrib_i = propagate_three(rel_i, irrel_i, b_i, sigmoid)
        rel_contrib_g, irrel_contrib_g, bias_contrib_g = propagate_three(rel_g, irrel_g, b_g, np.tanh)

        relevant[i] = rel_contrib_i * (rel_contrib_g + bias_contrib_g) + bias_contrib_i * rel_contrib_g
        irrelevant[i] = irrel_contrib_i * (rel_contrib_g + irrel_contrib_g + bias_contrib_g) + (
                    rel_contrib_i + bias_contrib_i) * irrel_contrib_g

        # the bias-bias product goes to whichever part the current word belongs to
        if start <= i <= stop:
            relevant[i] += bias_contrib_i * bias_contrib_g
        else:
            irrelevant[i] += bias_contrib_i * bias_contrib_g

        if i > 0:
            # forget gate carries over the previous step's relevant/irrelevant values
            rel_contrib_f, irrel_contrib_f, bias_contrib_f = propagate_three(rel_f, irrel_f, b_f, sigmoid)
            relevant[i] += (rel_contrib_f + bias_contrib_f) * relevant[i - 1]
            irrelevant[i] += (rel_contrib_f + irrel_contrib_f + bias_contrib_f) * irrelevant[i - 1] + irrel_contrib_f * \
                             relevant[i - 1]

        o = sigmoid(np.dot(W_io, word_vecs[i]) + np.dot(W_ho, prev_rel_h + prev_irrel_h) + b_o)
        # NOTE: the decomposed output-gate contributions are computed but not used
        # below; the full gate activation `o` scales both parts instead. The
        # commented lines show the alternative that would use them.
        rel_contrib_o, irrel_contrib_o, bias_contrib_o = propagate_three(rel_o, irrel_o, b_o, sigmoid)
        new_rel_h, new_irrel_h = propagate_tanh_two(relevant[i], irrelevant[i])
        # relevant_h[i] = new_rel_h * (rel_contrib_o + bias_contrib_o)
        # irrelevant_h[i] = new_rel_h * (irrel_contrib_o) + new_irrel_h * (rel_contrib_o + irrel_contrib_o + bias_contrib_o)
        relevant_h[i] = o * new_rel_h
        irrelevant_h[i] = o * new_irrel_h

    # NOTE(review): the right-hand side was missing from the listing this was
    # restored from; the weight of model.hidden_to_label is assumed - confirm.
    W_out = model.hidden_to_label.weight.data.cpu().numpy()

    # Sanity check: scores + irrel_scores should equal the LSTM's output minus model.hidden_to_label.bias
    scores = np.dot(W_out, relevant_h[T - 1])
    irrel_scores = np.dot(W_out, irrelevant_h[T - 1])

    if return_irrel_scores:
        return scores, irrel_scores

    return scores
def tanh(...)

tanh(input, out=None) -> Tensor

Returns a new tensor with the hyperbolic tangent of the elements of `input`.

\[ \text{out}_{i} = \tanh(\text{input}_{i}) \]


input : Tensor
the input tensor.
out : Tensor, optional
the output tensor.


>>> a = torch.randn(4)
>>> a
tensor([ 0.8986, -0.7279,  1.1745,  0.2611])
>>> torch.tanh(a)
tensor([ 0.7156, -0.6218,  0.8257,  0.2553])