# fast-mri/scripts/7.Visualize_saliency.py (file-viewer metadata: 60 lines, 2.3 KiB, Python, executable file)

import argparse
from email.mime import image
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
# Overlay per-series saliency maps on one slice of the input images and save
# the figure as <train_output>/<experiment>_<series...>_<fold>/saliency_map.png.

########## constants #################
OUTPUT_ROOT = './../train_output'  # resolved relative to the current working directory
DEFAULT_FOLD = 0
DEFAULT_SLIDE = 10        # index along the third array axis that is displayed
MASK_THRESHOLD = 0.10     # saliency magnitudes below this are not drawn
VMAX_SCALE = 0.5          # colour scale saturates at this fraction of the slice maximum
OVERLAY_ALPHA = 0.25


def build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description='Overlay saliency maps on the input series for one slice '
                    'and store the figure as saliency_map.png')
    # '-experiment' is kept for backward compatibility; '--experiment' is the
    # conventional spelling. dest is 'experiment' either way.
    parser.add_argument('-experiment', '--experiment', required=True,
                        help='Title of experiment')
    parser.add_argument('--series', '-s',
                        metavar='[series_name]', required=True, nargs='+',
                        help='List of series to include')
    parser.add_argument('--fold', type=int, default=DEFAULT_FOLD,
                        help='Fold index used in the train_output directory name '
                             f'(default: {DEFAULT_FOLD})')
    parser.add_argument('--slide', type=int, default=DEFAULT_SLIDE,
                        help='Index along the third array axis to display '
                             f'(default: {DEFAULT_SLIDE})')
    return parser


def train_output_dir(experiment, series, fold):
    """Return the training-output directory for an experiment/series/fold."""
    return f"{OUTPUT_ROOT}/{experiment}_{'_'.join(series)}_{fold}"


def squeeze_keep_series_axis(array, n_series):
    """Squeeze singleton axes of ``array`` while keeping the trailing series axis.

    The plotting code indexes arrays as ``[:, :, slide, series_index]``, i.e. it
    needs four axes with the series last. ``np.squeeze`` also removes that axis
    when only one series is present, so it is re-appended in that case.
    """
    squeezed = np.squeeze(array)
    if n_series == 1 and squeezed.ndim == 3:
        # Assumes the dropped singleton axis was the trailing series axis, as in
        # the multi-series layout — TODO confirm against the code writing the .npy files.
        squeezed = squeezed[..., np.newaxis]
    return squeezed


def plot_saliency(heatmap, images_list, series, slide, *,
                  threshold=MASK_THRESHOLD, vmax_scale=VMAX_SCALE,
                  alpha=OVERLAY_ALPHA):
    """Build a figure overlaying ``heatmap`` on ``images_list`` for one slice.

    Args:
        heatmap: 4-D array indexed ``[x, y, slice, series]`` (non-negative values
            expected; the caller takes the absolute value).
        images_list: array with the same indexing layout as ``heatmap``.
        series: series names, one per entry of the last axis; used as titles.
        slide: index along the third axis to display.
        threshold: heatmap values below this are masked (not drawn).
        vmax_scale: colour scale tops out at ``vmax_scale * max`` of the slice.
        alpha: opacity of the heatmap overlay.

    Returns:
        The matplotlib Figure (first row: overlays; second row: left blank).
    """
    # squeeze=False keeps axes 2-D even for a single series, so axes[row, col] works.
    fig, axes = plt.subplots(2, len(series), squeeze=False)

    slice_values = heatmap[:, :, slide, :]
    min_value = np.amin(slice_values)
    # Clamp so vmin <= vmax; otherwise matplotlib's Normalize raises at draw time.
    display_max = max(np.amax(slice_values) * vmax_scale, min_value)

    overlay = None
    for col, name in enumerate(series):
        ax = axes[0, col]
        heatmap_i = np.transpose(heatmap[:, :, slide, col])
        ax.imshow(np.transpose(images_list[:, :, slide, col]), cmap='gray')
        overlay = ax.imshow(np.ma.masked_where(heatmap_i < threshold, heatmap_i),
                            vmin=min_value, vmax=display_max,
                            alpha=alpha, cmap='jet')
        ax.set_title(name)
        ax.set_axis_off()
        # NOTE: second row is currently unused (segmentations.npy is not plotted).
        axes[1, col].set_axis_off()

    # Colorbar must describe the heatmap overlay, not the grayscale background.
    cbar = fig.colorbar(overlay, ax=axes.ravel().tolist(), shrink=0.5,
                        orientation='horizontal')
    cbar.set_ticks([min_value, display_max])
    cbar.set_ticklabels(['less important', 'important'])
    fig.suptitle('Saliency map', fontsize=16)
    return fig


def main(argv=None):
    """Load saliency/image arrays for one run and save the overlay figure."""
    args = build_parser().parse_args(argv)
    out_dir = train_output_dir(args.experiment, args.series, args.fold)
    n_series = len(args.series)

    # Magnitude of the saliency; no averaging over samples is applied.
    heatmap = np.abs(squeeze_keep_series_axis(np.load(f'{out_dir}/saliency.npy'), n_series))
    images_list = squeeze_keep_series_axis(np.load(f'{out_dir}/images_list.npy'), n_series)
    print(np.shape(heatmap))
    print(np.shape(images_list))

    fig = plot_saliency(heatmap, images_list, args.series, args.slide)
    fig.savefig(f'{out_dir}/saliency_map.png', dpi=300)
    plt.close(fig)


if __name__ == '__main__':
    main()