Module acd.scores.cd_propagate
from copy import deepcopy
import numpy as np
import torch
from scipy.special import expit as sigmoid
from torch import tanh

def propagate_conv_linear(relevant, irrelevant, module):
    '''Propagate convolutional or linear layer
    Apply linear part to both pieces
    Split bias based on the ratio of the absolute sums
    '''
    device = relevant.device
    bias = module(torch.zeros(irrelevant.size()).to(device))
    rel = module(relevant) - bias
    irrel = module(irrelevant) - bias

    # elementwise proportional
    prop_rel = torch.abs(rel) + 1e-20  # add a small constant so we don't divide by 0
    prop_irrel = torch.abs(irrel) + 1e-20  # add a small constant so we don't divide by 0
    prop_sum = prop_rel + prop_irrel
    prop_rel = torch.div(prop_rel, prop_sum)
    prop_irrel = torch.div(prop_irrel, prop_sum)
    return rel + torch.mul(prop_rel, bias), irrel + torch.mul(prop_irrel, bias)

def propagate_batchnorm2d(relevant, irrelevant, module):
    '''Propagate batchnorm2d operation
    '''
    device = relevant.device
    bias = module(torch.zeros(irrelevant.size()).to(device))
    rel = module(relevant) - bias
    irrel = module(irrelevant) - bias
    prop_rel = torch.abs(rel)
    prop_irrel = torch.abs(irrel)
    prop_sum = prop_rel + prop_irrel
    prop_rel = torch.div(prop_rel, prop_sum)
    prop_rel[torch.isnan(prop_rel)] = 0
    rel = rel + torch.mul(prop_rel, bias)
    irrel = module(relevant + irrelevant) - rel
    return rel, irrel

def propagate_pooling(relevant, irrelevant, pooler):
    '''propagate pooling operation
    '''
    # get both indices
    p = deepcopy(pooler)
    p.return_indices = True
    both, both_ind = p(relevant + irrelevant)

    # unpooling function
    def unpool(tensor, indices):
        '''Unpool tensor given indices for pooling
        '''
        batch_size, in_channels, H, W = indices.shape
        output = torch.ones_like(indices, dtype=torch.float)
        for i in range(batch_size):
            for j in range(in_channels):
                output[i, j] = tensor[i, j].flatten()[indices[i, j].flatten()].reshape(H, W)
        return output

    rel, irrel = unpool(relevant, both_ind), unpool(irrelevant, both_ind)
    return rel, irrel

def propagate_independent(relevant, irrelevant, module):
    '''Use for modules which operate on each piece independently
    e.g. avgpool, layer_norm, dropout
    '''
    return module(relevant), module(irrelevant)

def propagate_relu(relevant, irrelevant, activation):
    '''Propagate ReLU nonlinearity
    '''
    swap_inplace = False
    try:  # handles inplace
        if activation.inplace:
            swap_inplace = True
            activation.inplace = False
    except:
        pass
    rel_score = activation(relevant)
    irrel_score = activation(relevant + irrelevant) - activation(relevant)
    if swap_inplace:
        activation.inplace = True
    return rel_score, irrel_score

def propagate_three(a, b, c, activation):
    '''Propagate a three-part nonlinearity
    '''
    a_contrib = 0.5 * (activation(a + c) - activation(c) + activation(a + b + c) - activation(b + c))
    b_contrib = 0.5 * (activation(b + c) - activation(c) + activation(a + b + c) - activation(a + c))
    return a_contrib, b_contrib, activation(c)

def propagate_tanh_two(a, b):
    '''propagate tanh nonlinearity
    '''
    return 0.5 * (np.tanh(a) + (np.tanh(a + b) - np.tanh(b))), 0.5 * (np.tanh(b) + (np.tanh(a + b) - np.tanh(a)))

def propagate_basic_block(rel, irrel, module):
    '''Propagate a BasicBlock (used in the ResNet architectures)
    This is what the forward pass of the basic block looks like
    identity = x

    out = self.conv1(x)   # 1
    out = self.bn1(out)   # 2
    out = self.relu(out)  # 3
    out = self.conv2(out) # 4
    out = self.bn2(out)   # 5

    if self.downsample is not None:
        identity = self.downsample(x)

    out += identity
    out = self.relu(out)
    '''
    from .cd import cd_generic

    # for mod in module.modules():
    #     print('\tm', mod)
    rel_identity, irrel_identity = deepcopy(rel), deepcopy(irrel)
    rel, irrel = cd_generic(list(module.modules())[1:6], rel, irrel)

    if module.downsample is not None:
        rel_identity, irrel_identity = cd_generic(module.downsample.modules(), rel_identity, irrel_identity)
    rel += rel_identity
    irrel += irrel_identity
    rel, irrel = propagate_relu(rel, irrel, module.relu)
    return rel, irrel

def propagate_lstm(x, module, start: int, stop: int, my_device=0):
    '''module is an lstm layer

    Params
    ------
    module: lstm layer
    x: torch.Tensor
        (batch_size, seq_len, num_channels)
        warning: default lstm uses shape (seq_len, batch_size, num_channels)
    start: int
        start of relevant sequence
    stop: int
        end of relevant sequence

    Returns
    -------
    rel, irrel: torch.Tensor
        (batch_size, num_channels, num_hidden_lstm)
    '''
    # extract out weights
    W_ii, W_if, W_ig, W_io = torch.chunk(module.weight_ih_l0, 4, 0)
    W_hi, W_hf, W_hg, W_ho = torch.chunk(module.weight_hh_l0, 4, 0)
    b_i, b_f, b_g, b_o = torch.chunk(module.bias_ih_l0 + module.bias_hh_l0, 4)

    # prepare input x
    # x_orig = deepcopy(x)
    x = x.permute(1, 2, 0)  # convert to (seq_len, num_channels, batch_size)
    seq_len = x.shape[0]
    batch_size = x.shape[2]
    output_dim = W_ho.shape[1]
    relevant_h = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False)
    irrelevant_h = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False)
    prev_rel = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False)
    prev_irrel = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False)
    for i in range(seq_len):
        prev_rel_h = relevant_h
        prev_irrel_h = irrelevant_h
        rel_i = torch.matmul(W_hi, prev_rel_h)
        rel_g = torch.matmul(W_hg, prev_rel_h)
        rel_f = torch.matmul(W_hf, prev_rel_h)
        rel_o = torch.matmul(W_ho, prev_rel_h)
        irrel_i = torch.matmul(W_hi, prev_irrel_h)
        irrel_g = torch.matmul(W_hg, prev_irrel_h)
        irrel_f = torch.matmul(W_hf, prev_irrel_h)
        irrel_o = torch.matmul(W_ho, prev_irrel_h)

        if i >= start and i <= stop:
            rel_i = rel_i + torch.matmul(W_ii, x[i])
            rel_g = rel_g + torch.matmul(W_ig, x[i])
            rel_f = rel_f + torch.matmul(W_if, x[i])
            # rel_o = rel_o + torch.matmul(W_io, x[i])
        else:
            irrel_i = irrel_i + torch.matmul(W_ii, x[i])
            irrel_g = irrel_g + torch.matmul(W_ig, x[i])
            irrel_f = irrel_f + torch.matmul(W_if, x[i])
            # irrel_o = irrel_o + torch.matmul(W_io, x[i])

        rel_contrib_i, irrel_contrib_i, bias_contrib_i = propagate_three(rel_i, irrel_i, b_i[:, None], sigmoid)
        rel_contrib_g, irrel_contrib_g, bias_contrib_g = propagate_three(rel_g, irrel_g, b_g[:, None], tanh)

        relevant = rel_contrib_i * (rel_contrib_g + bias_contrib_g) + bias_contrib_i * rel_contrib_g
        irrelevant = irrel_contrib_i * (rel_contrib_g + irrel_contrib_g + bias_contrib_g) + (
            rel_contrib_i + bias_contrib_i) * irrel_contrib_g

        if i >= start and i < stop:
            relevant = relevant + bias_contrib_i * bias_contrib_g
        else:
            irrelevant = irrelevant + bias_contrib_i * bias_contrib_g

        if i > 0:
            rel_contrib_f, irrel_contrib_f, bias_contrib_f = propagate_three(rel_f, irrel_f, b_f[:, None], sigmoid)
            relevant = relevant + (rel_contrib_f + bias_contrib_f) * prev_rel
            irrelevant = irrelevant + (
                rel_contrib_f + irrel_contrib_f + bias_contrib_f) * prev_irrel + irrel_contrib_f * prev_rel

        o = sigmoid(torch.matmul(W_io, x[i]) + torch.matmul(W_ho, prev_rel_h + prev_irrel_h) + b_o[:, None])
        new_rel_h, new_irrel_h = propagate_tanh_two(relevant, irrelevant)
        relevant_h = o * new_rel_h
        irrelevant_h = o * new_irrel_h
        prev_rel = relevant
        prev_irrel = irrelevant

    # outputs, (h1, c1) = module(x_orig)
    # assert np.allclose((relevant_h + irrelevant_h).detach().numpy().flatten(),
    #                    h1.detach().numpy().flatten(), rtol=0.01)

    # reshape output
    rel_h = relevant_h.transpose(0, 1).unsqueeze(1)
    irrel_h = irrelevant_h.transpose(0, 1).unsqueeze(1)
    return rel_h, irrel_h

def propagate_lstm_block(x_rel, x_irrel, module, start: int, stop: int, my_device=0):
    '''module is an lstm layer. This function is still experimental

    Params
    ------
    module: lstm layer
    x_rel: torch.Tensor
        (batch_size, seq_len, num_channels)
        warning: default lstm uses shape (seq_len, batch_size, num_channels)
    x_irrel: torch.Tensor
        (batch_size, seq_len, num_channels)
    start: int
        start of relevant sequence
    stop: int
        end of relevant sequence
    weights: torch.Tensor
        (seq_len)

    Returns
    -------
    rel, irrel: torch.Tensor
        (batch_size, num_channels, num_hidden_lstm)
    '''
    # extract out weights
    W_ii, W_if, W_ig, W_io = torch.chunk(module.weight_ih_l0, 4, 0)
    W_hi, W_hf, W_hg, W_ho = torch.chunk(module.weight_hh_l0, 4, 0)
    b_i, b_f, b_g, b_o = torch.chunk(module.bias_ih_l0 + module.bias_hh_l0, 4)

    # prepare input x
    # x_orig = deepcopy(x)
    x_rel = x_rel.permute(1, 2, 0)  # convert to (seq_len, num_channels, batch_size)
    x_irrel = x_irrel.permute(1, 2, 0)  # convert to (seq_len, num_channels, batch_size)
    x = x_rel + x_irrel
    # print('shapes', x_rel.shape, x_irrel.shape, x.shape)
    seq_len = x_rel.shape[0]
    batch_size = x_rel.shape[2]
    output_dim = W_ho.shape[1]
    relevant_h = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False)
    irrelevant_h = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False)
    prev_rel = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False)
    prev_irrel = torch.zeros((output_dim, batch_size), device=torch.device(my_device), requires_grad=False)
    for i in range(seq_len):
        prev_rel_h = relevant_h
        prev_irrel_h = irrelevant_h
        rel_i = torch.matmul(W_hi, prev_rel_h)
        rel_g = torch.matmul(W_hg, prev_rel_h)
        rel_f = torch.matmul(W_hf, prev_rel_h)
        rel_o = torch.matmul(W_ho, prev_rel_h)
        irrel_i = torch.matmul(W_hi, prev_irrel_h)
        irrel_g = torch.matmul(W_hg, prev_irrel_h)
        irrel_f = torch.matmul(W_hf, prev_irrel_h)
        irrel_o = torch.matmul(W_ho, prev_irrel_h)

        # relevant parts
        rel_i = rel_i + torch.matmul(W_ii, x_rel[i])
        rel_g = rel_g + torch.matmul(W_ig, x_rel[i])
        rel_f = rel_f + torch.matmul(W_if, x_rel[i])
        # rel_o = rel_o + torch.matmul(W_io, x[i])

        # irrelevant parts
        irrel_i = irrel_i + torch.matmul(W_ii, x_irrel[i])
        irrel_g = irrel_g + torch.matmul(W_ig, x_irrel[i])
        irrel_f = irrel_f + torch.matmul(W_if, x_irrel[i])
        # irrel_o = irrel_o + torch.matmul(W_io, x[i])

        rel_contrib_i, irrel_contrib_i, bias_contrib_i = propagate_three(rel_i, irrel_i, b_i[:, None], sigmoid)
        rel_contrib_g, irrel_contrib_g, bias_contrib_g = propagate_three(rel_g, irrel_g, b_g[:, None], tanh)

        relevant = rel_contrib_i * (rel_contrib_g + bias_contrib_g) + \
                   bias_contrib_i * rel_contrib_g
        irrelevant = irrel_contrib_i * (rel_contrib_g + irrel_contrib_g + bias_contrib_g) + \
                     (rel_contrib_i + bias_contrib_i) * irrel_contrib_g

        # if i >= start and i < stop:
        relevant = relevant + bias_contrib_i * bias_contrib_g
        # else:
        irrelevant = irrelevant + bias_contrib_i * bias_contrib_g

        if i > 0:
            rel_contrib_f, irrel_contrib_f, bias_contrib_f = propagate_three(rel_f, irrel_f, b_f[:, None], sigmoid)
            relevant = relevant + (rel_contrib_f + bias_contrib_f) * prev_rel
            irrelevant = irrelevant + (
                rel_contrib_f + irrel_contrib_f + bias_contrib_f) * prev_irrel + irrel_contrib_f * prev_rel

        o = sigmoid(torch.matmul(W_io, x[i]) + torch.matmul(W_ho, prev_rel_h + prev_irrel_h) + b_o[:, None])
        new_rel_h, new_irrel_h = propagate_tanh_two(relevant, irrelevant)
        relevant_h = o * new_rel_h
        irrelevant_h = o * new_irrel_h
        prev_rel = relevant
        prev_irrel = irrelevant

    # outputs, (h1, c1) = module(x_orig)
    # assert np.allclose((relevant_h + irrelevant_h).detach().numpy().flatten(),
    #                    h1.detach().numpy().flatten(), rtol=0.01)

    # reshape output
    rel_h = relevant_h.transpose(0, 1).unsqueeze(1)
    irrel_h = irrelevant_h.transpose(0, 1).unsqueeze(1)
    return rel_h, irrel_h
Functions
def propagate_basic_block(rel, irrel, module)
-
Propagate a BasicBlock (used in the ResNet architectures). This is what the forward pass of the basic block looks like:

    identity = x
    out = self.conv1(x)   # 1
    out = self.bn1(out)   # 2
    out = self.relu(out)  # 3
    out = self.conv2(out) # 4
    out = self.bn2(out)   # 5
    if self.downsample is not None:
        identity = self.downsample(x)
    out += identity
    out = self.relu(out)
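A usage sketch (not part of the source), assuming the acd package and torchvision are installed; the relevant/irrelevant split of the input is arbitrary and only for illustration. It checks the completeness property: the two scores add back up to the block's ordinary forward pass.

    import torch
    from torchvision.models.resnet import BasicBlock
    from acd.scores.cd_propagate import propagate_basic_block

    block = BasicBlock(inplanes=64, planes=64).eval()  # no downsample branch, eval-mode batchnorm
    x = torch.randn(1, 64, 56, 56)
    rel = torch.zeros_like(x)
    rel[:, :, :28, :] = x[:, :, :28, :]   # treat the top half of the feature map as "relevant"
    irrel = x - rel                       # the remainder is "irrelevant"
    with torch.no_grad():
        rel_out, irrel_out = propagate_basic_block(rel, irrel, block)
        # completeness: the two scores recombine to the ordinary forward pass
        assert torch.allclose(rel_out + irrel_out, block(x), atol=1e-4)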
def propagate_batchnorm2d(relevant, irrelevant, module)
-
Propagate batchnorm2d operation
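A quick completeness check (a sketch, assuming the acd package is importable): for an eval-mode BatchNorm2d, the two returned pieces add back up to the full output.

    import torch
    from torch import nn
    from acd.scores.cd_propagate import propagate_batchnorm2d

    bn = nn.BatchNorm2d(8).eval()      # eval mode, so running statistics are used
    x = torch.randn(2, 8, 4, 4)
    rel, irrel = 0.3 * x, 0.7 * x      # any additive split of the input
    r, ir = propagate_batchnorm2d(rel, irrel, bn)
    assert torch.allclose(r + ir, bn(x), atol=1e-5)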
def propagate_conv_linear(relevant, irrelevant, module)
-
Propagate a convolutional or linear layer. Apply the linear part to both pieces; split the bias elementwise based on the ratio of the absolute values of the two pieces.
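A toy sketch on a single nn.Linear layer (assumes the acd package is importable): the linear part is applied to each piece and the bias is shared out elementwise, so the two pieces recombine to the full output.

    import torch
    from torch import nn
    from acd.scores.cd_propagate import propagate_conv_linear

    lin = nn.Linear(4, 3)
    x = torch.randn(5, 4)
    rel, irrel = 0.9 * x, 0.1 * x      # any additive split of the input
    r, ir = propagate_conv_linear(rel, irrel, lin)
    # completeness: the pieces always add back up to the full layer output
    assert torch.allclose(r + ir, lin(x), atol=1e-5)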
def propagate_independent(relevant, irrelevant, module)
-
Use for modules which operate on each piece independently, e.g. avgpool, layer_norm, dropout.
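For example (a sketch, assuming the acd package is importable), an average-pooling layer is linear, so propagating each piece separately preserves completeness:

    import torch
    from torch import nn
    from acd.scores.cd_propagate import propagate_independent

    pool = nn.AvgPool2d(2)
    rel, irrel = torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8)
    r, ir = propagate_independent(rel, irrel, pool)
    assert torch.allclose(r + ir, pool(rel + irrel), atol=1e-6)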
def propagate_lstm(x, module, start, stop, my_device=0)
-
module is an lstm layer

Params
------
module : lstm layer
x : torch.Tensor
    (batch_size, seq_len, num_channels)
    warning: default lstm uses shape (seq_len, batch_size, num_channels)
start : int
    start of relevant sequence
stop : int
    end of relevant sequence

Returns
-------
rel, irrel : torch.Tensor
    (batch_size, num_channels, num_hidden_lstm)
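A usage sketch (assumes the acd package is importable): decompose the final hidden state of a single-layer, batch-first LSTM into contributions from tokens start..stop versus the rest. The sizes below are arbitrary; pass my_device='cpu' when running without a GPU.

    import torch
    from torch import nn
    from acd.scores.cd_propagate import propagate_lstm

    lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=1, batch_first=True)
    x = torch.randn(2, 10, 16)         # (batch_size, seq_len, num_channels)
    with torch.no_grad():              # the scipy sigmoid needs tensors that do not require grad
        rel_h, irrel_h = propagate_lstm(x, lstm, start=3, stop=6, my_device='cpu')
    print(rel_h.shape, irrel_h.shape)  # torch.Size([2, 1, 32]) for both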
def propagate_lstm_block(x_rel, x_irrel, module, start, stop, my_device=0)
-
module is an lstm layer. This function is still experimental.

Params
------
module : lstm layer
x_rel : torch.Tensor
    (batch_size, seq_len, num_channels)
    warning: default lstm uses shape (seq_len, batch_size, num_channels)
x_irrel : torch.Tensor
    (batch_size, seq_len, num_channels)
start : int
    start of relevant sequence
stop : int
    end of relevant sequence
weights : torch.Tensor
    (seq_len)

Returns
-------
rel, irrel : torch.Tensor
    (batch_size, num_channels, num_hidden_lstm)
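A sketch for the experimental block variant (assumes the acd package is importable), which takes an input that is already split into relevant and irrelevant embeddings; start and stop appear unused in the current loop body but are still required by the signature.

    import torch
    from torch import nn
    from acd.scores.cd_propagate import propagate_lstm_block

    lstm = nn.LSTM(input_size=16, hidden_size=32, num_layers=1, batch_first=True)
    x = torch.randn(2, 10, 16)
    x_rel, x_irrel = 0.5 * x, 0.5 * x  # any additive split of the embeddings
    with torch.no_grad():
        rel_h, irrel_h = propagate_lstm_block(x_rel, x_irrel, lstm, start=0, stop=9, my_device='cpu')
    print(rel_h.shape)                 # torch.Size([2, 1, 32])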
def propagate_pooling(relevant, irrelevant, pooler)
-
propagate pooling operation
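A toy check (assumes the acd package is importable): the pooler must support return_indices (e.g. nn.MaxPool2d), since the winning locations of the combined input are reused to gather both pieces.

    import torch
    from torch import nn
    from acd.scores.cd_propagate import propagate_pooling

    pool = nn.MaxPool2d(kernel_size=2, stride=2)
    x = torch.randn(1, 3, 8, 8)
    rel, irrel = 0.8 * x, 0.2 * x
    r, ir = propagate_pooling(rel, irrel, pool)
    assert torch.allclose(r + ir, pool(x), atol=1e-6)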
def propagate_relu(relevant, irrelevant, activation)
-
Propagate ReLU nonlinearity
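A quick check (assumes the acd package is importable): the relevant piece gets activation(relevant), the irrelevant piece gets the remainder, and in-place ReLUs are handled.

    import torch
    from torch import nn
    from acd.scores.cd_propagate import propagate_relu

    relu = nn.ReLU(inplace=True)                 # inplace is temporarily disabled internally
    x = torch.randn(4, 10)
    rel, irrel = x.clamp(min=0), x.clamp(max=0)  # an additive split of x
    r, ir = propagate_relu(rel, irrel, relu)
    assert torch.allclose(r + ir, torch.relu(x))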
def propagate_tanh_two(a, b)
-
propagate tanh nonlinearity
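The two returned pieces symmetrize tanh's contributions and always sum to tanh(a + b); a quick numerical check on numpy arrays (assumes the acd package is importable):

    import numpy as np
    from acd.scores.cd_propagate import propagate_tanh_two

    a, b = np.random.randn(5), np.random.randn(5)
    rel, irrel = propagate_tanh_two(a, b)
    assert np.allclose(rel + irrel, np.tanh(a + b))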
def propagate_three(a, b, c, activation)
-
Propagate a three-part nonlinearity
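The three contributions recombine exactly to activation(a + b + c); a quick check with the sigmoid used elsewhere in this module (assumes the acd package is importable):

    import numpy as np
    from scipy.special import expit as sigmoid
    from acd.scores.cd_propagate import propagate_three

    a, b, c = np.random.randn(3, 7)    # three arrays of shape (7,)
    a_c, b_c, c_c = propagate_three(a, b, c, sigmoid)
    assert np.allclose(a_c + b_c + c_c, sigmoid(a + b + c))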
def tanh(...)
-
tanh(input, out=None) -> Tensor

Returns a new tensor with the hyperbolic tangent of the elements of input.

\[ \text{out}_{i} = \tanh(\text{input}_{i}) \]

Args
----
input : Tensor
    the input tensor.
out : Tensor, optional
    the output tensor.

Example::

    >>> a = torch.randn(4)
    >>> a
    tensor([ 0.8986, -0.7279,  1.1745,  0.2611])
    >>> torch.tanh(a)
    tensor([ 0.7156, -0.6218,  0.8257,  0.2553])