Module src.viz
Expand source code
from math import sqrt
import matplotlib.pyplot as plt
import numpy as np
from .config import DIR_FIGS
def savefig(fname):
if '.' in fname:
print('filename should not contain extension!')
if not fname.startswith('fig_'):
fname = 'fig_' + fname
os.makedirs(DIR_FIGS, exist_ok=True)
plt.tight_layout()
plt.savefig(oj(DIR_FIGS, fname) + '.pdf', bbox_inches='tight')
plt.savefig(oj(DIR_FIGS, fname) + '.png', dpi=300, bbox_inches='tight')
def corrplot(corrs):
mask = np.triu(np.ones_like(corrs, dtype=np.bool))
corrs[mask] = np.nan
max_abs = np.nanmax(np.abs(corrs))
plt.imshow(corrs, cmap=style.cmap_div, vmax=max_abs, vmin=-max_abs)
def plot_row(images, annot_list: list = None, dpi: int = 100,
suptitle: str = None, ylab: str = None, fontsize_ylab=25):
'''
Params
------
images: np.ndarray
(num_images, H, W, C)
'''
# deal with inputs
if type(images) == list:
N_IMS = len(images)
else:
N_IMS = images.shape[0]
if annot_list is None:
annot_list = [None] * N_IMS
fig = plt.figure(figsize=(N_IMS * 3, 3), dpi=dpi)
for i in range(N_IMS):
ax = plt.subplot(1, N_IMS, i + 1)
imshow(images[i], annot=annot_list[i])
if i == 0:
show_ylab(ax, ylab, fontsize_ylab=fontsize_ylab)
# plt.ylabel(ylab, fontsize=fontsize_ylab)
# fig.text(0, 0.5, ylab, rotation=90, va='center', fontsize=fontsize_ylab)
if suptitle is not None:
plt.subplot(1, N_IMS, N_IMS // 2 + 1)
plt.title(suptitle)
# if ylab is not None:
plt.tight_layout()
def plot_grid(images, ylabs=[], annot_list=None, suptitle=None, emphasize_col: int = None, fontsize_ylab=25):
'''
Params
------
images: np.ndarray
(R, C, H, W, C)
emphasize_col
which column to emphasize (by not removing black border)
'''
# deal with inputs
if type(images) == list:
images = np.array(images)
# print(images.shape)
# check if wasn't passed a grid
if len(images.shape) == 4:
N_IMS = images.shape[0]
R = int(np.sqrt(N_IMS))
C = R + 1
else:
R = images.shape[0]
C = images.shape[1]
N_IMS = R * C
# reshape to be (R * C, H, W, C)
images = images.reshape((R * C, *images.shape[2:]))
if annot_list is None:
annot_list = [None] * N_IMS
i = 0
fig = plt.figure(figsize=(C * 3, R * 3))
for r in range(R):
for c in range(C):
ax = plt.subplot(R, C, i + 1)
imshow(images[r * C + c], annot=annot_list[i])
if c == 0 and len(ylabs) > r:
show_ylab(ax, ylabs[r], fontsize_ylab=fontsize_ylab)
i += 1
if i >= images.shape[0]:
break
if c == emphasize_col:
emphasize_box(ax)
if suptitle is not None:
fig.text(0.5, 1, suptitle, ha='center')
'''
if ylabs is not None:
for r in range(R):
fig.text(0, r / R + 0.5 / R, ylabs[R - 1 - r], rotation=90,
va='center', fontsize=fontsize_ylab)
'''
fig.tight_layout()
def show_ylab(ax, ylab, fontsize_ylab):
plt.axis('on')
ax.get_yaxis().set_ticks([])
ax.get_xaxis().set_ticks([])
for x in ['right', 'top', 'bottom', 'left']:
ax.spines[x].set_visible(False)
plt.ylabel(ylab, fontsize=fontsize_ylab)
def emphasize_box(ax):
plt.axis('on')
ax.get_yaxis().set_ticks([])
ax.get_xaxis().set_ticks([])
for x in ['right', 'top', 'bottom', 'left']:
ax.spines[x].set_visible(True)
ax.spines[x].set_linewidth(3) # ['linewidth'] = 10
# [i.set_linewidth(0.1) for i in ax.spines.itervalues()]
# ax.spines['top'].set_visible(True)
def norm(im):
'''Normalize to [0, 1]
'''
return (im - np.min(im)) / (np.max(im) - np.min(im)) # converts range to [0, 1]
def imshow(im, annot: str = None):
'''
Params
------
annot
str to put in top-right corner
'''
# if 4d, take first image
if len(im.shape) > 3:
im = im[0]
# if channels dimension first, transpose
if im.shape[0] == 3 and len(im.shape) == 3:
im = im.transpose()
ax = plt.gca()
ax.imshow(im)
ax.axis('off')
if annot is not None:
padding = 5
ax.annotate(
s=annot,
fontsize=12,
xy=(0, 0),
xytext=(padding - 1, -(padding - 1)),
textcoords='offset pixels',
bbox=dict(facecolor='white', alpha=1, pad=padding),
va='top',
ha='left')
def detach(tensor):
return tensor.detach().cpu().numpy()
Functions
def corrplot(corrs)
-
Expand source code
def corrplot(corrs): mask = np.triu(np.ones_like(corrs, dtype=np.bool)) corrs[mask] = np.nan max_abs = np.nanmax(np.abs(corrs)) plt.imshow(corrs, cmap=style.cmap_div, vmax=max_abs, vmin=-max_abs)
def detach(tensor)
-
Expand source code
def detach(tensor): return tensor.detach().cpu().numpy()
def emphasize_box(ax)
-
Expand source code
def emphasize_box(ax): plt.axis('on') ax.get_yaxis().set_ticks([]) ax.get_xaxis().set_ticks([]) for x in ['right', 'top', 'bottom', 'left']: ax.spines[x].set_visible(True) ax.spines[x].set_linewidth(3) # ['linewidth'] = 10
def imshow(im, annot=None)
-
Params
annot
- str to put in top-right corner
Expand source code
def imshow(im, annot: str = None): ''' Params ------ annot str to put in top-right corner ''' # if 4d, take first image if len(im.shape) > 3: im = im[0] # if channels dimension first, transpose if im.shape[0] == 3 and len(im.shape) == 3: im = im.transpose() ax = plt.gca() ax.imshow(im) ax.axis('off') if annot is not None: padding = 5 ax.annotate( s=annot, fontsize=12, xy=(0, 0), xytext=(padding - 1, -(padding - 1)), textcoords='offset pixels', bbox=dict(facecolor='white', alpha=1, pad=padding), va='top', ha='left')
def norm(im)
-
Normalize to [0, 1]
Expand source code
def norm(im): '''Normalize to [0, 1] ''' return (im - np.min(im)) / (np.max(im) - np.min(im)) # converts range to [0, 1]
def plot_grid(images, ylabs=[], annot_list=None, suptitle=None, emphasize_col=None, fontsize_ylab=25)
-
Params
images
:np.ndarray
- (R, C, H, W, C)
emphasize_col
- which column to emphasize (by not removing black border)
Expand source code
def plot_grid(images, ylabs=[], annot_list=None, suptitle=None, emphasize_col: int = None, fontsize_ylab=25): ''' Params ------ images: np.ndarray (R, C, H, W, C) emphasize_col which column to emphasize (by not removing black border) ''' # deal with inputs if type(images) == list: images = np.array(images) # print(images.shape) # check if wasn't passed a grid if len(images.shape) == 4: N_IMS = images.shape[0] R = int(np.sqrt(N_IMS)) C = R + 1 else: R = images.shape[0] C = images.shape[1] N_IMS = R * C # reshape to be (R * C, H, W, C) images = images.reshape((R * C, *images.shape[2:])) if annot_list is None: annot_list = [None] * N_IMS i = 0 fig = plt.figure(figsize=(C * 3, R * 3)) for r in range(R): for c in range(C): ax = plt.subplot(R, C, i + 1) imshow(images[r * C + c], annot=annot_list[i]) if c == 0 and len(ylabs) > r: show_ylab(ax, ylabs[r], fontsize_ylab=fontsize_ylab) i += 1 if i >= images.shape[0]: break if c == emphasize_col: emphasize_box(ax) if suptitle is not None: fig.text(0.5, 1, suptitle, ha='center') ''' if ylabs is not None: for r in range(R): fig.text(0, r / R + 0.5 / R, ylabs[R - 1 - r], rotation=90, va='center', fontsize=fontsize_ylab) ''' fig.tight_layout()
def plot_row(images, annot_list=None, dpi=100, suptitle=None, ylab=None, fontsize_ylab=25)
-
Params
images
:np.ndarray
- (num_images, H, W, C)
Expand source code
def plot_row(images, annot_list: list = None, dpi: int = 100, suptitle: str = None, ylab: str = None, fontsize_ylab=25): ''' Params ------ images: np.ndarray (num_images, H, W, C) ''' # deal with inputs if type(images) == list: N_IMS = len(images) else: N_IMS = images.shape[0] if annot_list is None: annot_list = [None] * N_IMS fig = plt.figure(figsize=(N_IMS * 3, 3), dpi=dpi) for i in range(N_IMS): ax = plt.subplot(1, N_IMS, i + 1) imshow(images[i], annot=annot_list[i]) if i == 0: show_ylab(ax, ylab, fontsize_ylab=fontsize_ylab) # plt.ylabel(ylab, fontsize=fontsize_ylab) # fig.text(0, 0.5, ylab, rotation=90, va='center', fontsize=fontsize_ylab) if suptitle is not None: plt.subplot(1, N_IMS, N_IMS // 2 + 1) plt.title(suptitle) # if ylab is not None: plt.tight_layout()
def savefig(fname)
-
Expand source code
def savefig(fname): if '.' in fname: print('filename should not contain extension!') if not fname.startswith('fig_'): fname = 'fig_' + fname os.makedirs(DIR_FIGS, exist_ok=True) plt.tight_layout() plt.savefig(oj(DIR_FIGS, fname) + '.pdf', bbox_inches='tight') plt.savefig(oj(DIR_FIGS, fname) + '.png', dpi=300, bbox_inches='tight')
def show_ylab(ax, ylab, fontsize_ylab)
-
Expand source code
def show_ylab(ax, ylab, fontsize_ylab): plt.axis('on') ax.get_yaxis().set_ticks([]) ax.get_xaxis().set_ticks([]) for x in ['right', 'top', 'bottom', 'left']: ax.spines[x].set_visible(False) plt.ylabel(ylab, fontsize=fontsize_ylab)