# fast-mri/scripts/16.plot_paper_saliency.py
#
# 89 lines
# 3.6 KiB
# Python
# Executable File

import argparse
from ast import Slice
from email.mime import image
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
# Command-line interface.
# NOTE: `-experiment` is parsed into `args.experiment`, but the plotting loop
# further down iterates over the hard-coded `experiments` list, not this value.
parser = argparse.ArgumentParser(
    description='Plot saliency maps over the input series, plus the '
                'prediction, for each experiment and save the figure as a PNG')
parser.add_argument('-experiment',
                    help='Title of experiment')
# One figure column is drawn per series given here, plus one prediction column.
parser.add_argument('--series', '-s',
                    metavar='[series_name]', required=True, nargs='+',
                    help='List of series to include')
args = parser.parse_args()
########## constants #################
SERIES = args.series
# `series_`, `EXPERIMENT` and `fold` are not read anywhere in the visible part
# of this script; they are kept in case other code relies on them.
series_ = '_'.join(args.series)
EXPERIMENT = args.experiment
fold = 0
# One figure row per experiment; each name is a sub-directory of
# ./../train_output/ from which the .npy arrays are loaded.
experiments = [
    'calc_exp_t2_b1400calc3_adccalc3_0',
    'calc_exp_t2_b1400calc2_adccalc2_0',
    'calc_exp_t2_b1400calc_adccalc_0',
]
# Grid: rows = experiments, columns = one per series + one for the prediction.
# The row count follows the list instead of a hard-coded 3, and squeeze=False
# guarantees `axes` is always 2-D so `axes[row, col]` indexing keeps working
# even if the list is shortened to a single experiment.
fig, axes = plt.subplots(len(experiments), len(SERIES) + 1, squeeze=False)
# Slice index shown in every panel and the panel labels. None of these depend
# on the experiment, so they are defined once instead of on every iteration.
SLIDE = 10  # pat_idx 371 = pat0623
# SLIDE = 7  # pat_idx ?? = pat023
# NOTE(review): the comment above ties SLIDE = 10 to pat0623, while the output
# file name below says "pat23" -- confirm which patient this figure shows.
TITLES = ['$T2_{tra}$', '$DWI_{b1400}$', 'ADC', 'Prediction']  # column titles
titles = ['All b-values', 'Omitting b800', 'Omitting b400']    # row labels

# Draw one figure row per experiment: the input series with the saliency map
# overlaid, followed by the prediction with the segmentation overlaid.
for idx, experiment in enumerate(experiments):
    print(idx)
    # Despite the *_DIR names these are paths to individual .npy files.
    SALIENCY_DIR = f'./../train_output/{experiment}/saliency_new.npy'  # _new23
    IMAGES_DIR = f'./../train_output/{experiment}/images_list_new.npy'  # _new23
    SEGMENTATION_DIR = f'./../train_output/{experiment}/segmentations_new.npy'  # _new23
    predictions_DIR = f'./../train_output/{experiment}/predictions_new.npy'  # _new23

    ########## load saliency map ############
    heatmap = np.load(SALIENCY_DIR)
    heatmap = np.squeeze(heatmap)

    ######### load images and segmentations ###########
    images_list = np.load(IMAGES_DIR)
    images_list = np.squeeze(images_list)
    segmentations = np.squeeze(np.load(SEGMENTATION_DIR))
    predictions = np.squeeze(np.load(predictions_DIR))

    ######## take average ##########
    # len(heatmap) is smaller than the maximum number of images
    # if len(heatmap) < 100:
    #     heatmap = np.mean(abs(heatmap), axis=0)
    # Only the magnitude of the saliency is displayed.
    heatmap = abs(heatmap)

    print(np.shape(predictions))
    print(np.shape(segmentations))
    print(np.shape(images_list))

    # Colour range of the overlay, taken over all series of the shown slice.
    # The indexing below treats heatmap/images_list as
    # [axis0, axis1, slice, series] and predictions/segmentations as
    # [slice, axis0, axis1].
    max_value = np.amax(heatmap[:, :, SLIDE, :])
    min_value = np.amin(heatmap[:, :, SLIDE, :])

    for indx in range(len(SERIES) + 1):
        print(indx)
        ax = axes[idx, indx]
        # `==` rather than `is`: identity comparison of ints only works by
        # accident of CPython's small-int cache and warns on literals.
        if indx == len(SERIES):
            # Last column: prediction with the segmentation on top.
            im = ax.imshow(predictions[SLIDE, :, :], cmap='gray')
            print(np.amax(predictions[SLIDE, :, :]))
            seg = segmentations[SLIDE, :, :]
            ax.imshow(np.ma.masked_where(seg < 0.10, seg), alpha=0.5,
                      vmin=np.amin(seg), vmax=np.amax(seg), cmap='bwr')
            if idx == 0:
                # The prediction label is the last entry of TITLES; indexing
                # with -1 keeps it correct when fewer than three series are
                # plotted.
                ax.set_title(TITLES[-1])
            # The original issued this call twice; once is sufficient.
            ax.set_axis_off()
        else:
            # Series column: grayscale image with the saliency map on top.
            heatmap_i = np.transpose(np.squeeze(heatmap[:, :, SLIDE, indx]))
            im = ax.imshow(np.transpose(images_list[:, :, SLIDE, indx]),
                           cmap='gray')
            # Values below 0.10 are masked out; the upper colour limit is half
            # of the maximum of the slice.
            ax.imshow(np.ma.masked_where(heatmap_i < 0.10, heatmap_i),
                      vmin=min_value, vmax=max_value * 0.5, alpha=0.25,
                      cmap='jet')
            if idx == 0:
                ax.set_title(TITLES[indx])
            # Ticks are removed instead of switching the axis off so that the
            # row label set below remains visible.
            # ax.set_axis_off()
            ax.set_yticks([])
            ax.set_xticks([])
            if indx == 0:
                ax.set_ylabel(titles[idx])

# cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.5, orientation='horizontal')
# cbar.set_ticks([min_value,max_value])
# cbar.set_ticklabels(['less important', 'important'])
# fig.suptitle('Saliency map', fontsize=16)
plt.savefig('./../train_output/saliency_map_paper_pat23.png', dpi=300)