Cleanup of scripts. Update of the saliency visualization.

Stefan 2022-03-21 14:31:44 +01:00
parent 02d5b371d6
commit 01f458d0db
5 changed files with 157 additions and 128 deletions


@@ -42,7 +42,7 @@ INPUT_SHAPE = (192, 192, 24, len(SERIES))
 IMAGE_SHAPE = INPUT_SHAPE[:3]
 DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
-TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
+TEST_INDEX = DATA_SPLIT_INDEX['test_set0']
 N_CPUS = 12
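For context, the split file referenced here appears to map split names such as 'val_set0' and 'test_set0' to lists of image indices. A minimal sketch of reading it, assuming read_yaml_to_dict (from sfransen.utils_quintin) is a thin wrapper around PyYAML; the function body below is an assumption, not the package's actual implementation:

import yaml

def read_yaml_to_dict(path):
    # assumed re-implementation of sfransen.utils_quintin.read_yaml_to_dict
    with open(path, 'r') as f:
        return yaml.safe_load(f)

split = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = split['test_set0']  # this commit switches from 'val_set0' to the test split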


@@ -1,53 +1,60 @@
-import sys
+import argparse
 from os import path
 import SimpleITK as sitk
 import tensorflow as tf
-from tensorflow import keras
 from tensorflow.keras.models import load_model
-from focal_loss import BinaryFocalLoss
-import json
-import matplotlib.pyplot as plt
 import numpy as np
-from sfransen.Saliency.base import *
-from sfransen.Saliency.integrated_gradients import *
 from sfransen.utils_quintin import *
 from sfransen.DWI_exp import preprocess
 from sfransen.DWI_exp.helpers import *
 from sfransen.DWI_exp.callbacks import dice_coef
 from sfransen.FROC.blob_preprocess import *
 from sfransen.FROC.cal_froc_from_np import *
-# train_10h_t2_b50_b400_b800_b1400_adc
-SERIES = ['t2','b50','b400','b800','b1400','adc']
-MODEL_PATH = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/models/train_10h_t2_b50_b400_b800_b1400_adc.h5'
-YAML_DIR = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc'
-################ constants ############
+from sfransen.load_images import load_images_parrallel
+from sfransen.Saliency.base import *
+from sfransen.Saliency.integrated_gradients import *
+parser = argparse.ArgumentParser(
+    description='Calculate the froc metrics and store in froc_metrics.yml')
+parser.add_argument('-experiment',
+                    help='Title of experiment')
+parser.add_argument('--series', '-s',
+                    metavar='[series_name]', required=True, nargs='+',
+                    help='List of series to include')
+args = parser.parse_args()
+######## CUDA ################
+os.environ["CUDA_VISIBLE_DEVICES"] = "2"
+######## constants #############
+SERIES = args.series
+series_ = '_'.join(args.series)
+EXPERIMENT = args.experiment
+MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}/models/{EXPERIMENT}_{series_}.h5'
+YAML_DIR = f'./../train_output/{EXPERIMENT}_{series_}'
 DATA_DIR = "./../data/Nijmegen paths/"
 TARGET_SPACING = (0.5, 0.5, 3)
 INPUT_SHAPE = (192, 192, 24, len(SERIES))
 IMAGE_SHAPE = INPUT_SHAPE[:3]
-# import val_indx
-# DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
-# TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
-experiment_path = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/froc_metrics.yml'
-experiment_metrics = read_yaml_to_dict(experiment_path)
+froc_metrics = read_yaml_to_dict(f'{YAML_DIR}/froc_metrics.yml')
+top_10_idx = np.argsort(froc_metrics['roc_pred'])[-1:]
 DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
 TEST_INDEX = DATA_SPLIT_INDEX['test_set0']
-top_10_idx = np.argsort(experiment_metrics['roc_pred'])[-10:]
-TEST_INDEX = [TEST_INDEX[i] for i in top_10_idx]
-########## load images ##############
-images, image_paths = {s: [] for s in SERIES}, {}
-segmentations = []
+TEST_INDEX_top10 = [TEST_INDEX[i] for i in top_10_idx]
+N_CPUS = 12
+########## load images in parrallel ##############
 print_(f"> Loading images into RAM...")
+# read paths from txt
+image_paths = {}
 for s in SERIES:
     with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f:
         image_paths[s] = [l.strip() for l in f.readlines()]
@@ -55,45 +62,52 @@ with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f:
     seg_paths = [l.strip() for l in f.readlines()]
 num_images = len(seg_paths)
-# Read and preprocess each of the paths for each SERIES, and the segmentations.
-for img_idx in TEST_INDEX[:5]: #for less images
-    img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)
-             for s in SERIES}
-    seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
-    img_n, seg_n = preprocess(img_s, seg_s,
-                              shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
-    for seq in img_n:
-        images[seq].append(img_n[seq])
-    segmentations.append(seg_n)
-images_list = [images[s] for s in images.keys()]
-images_list = np.transpose(images_list, (1, 2, 3, 4, 0))
+# create pool of workers
+pool = multiprocessing.Pool(processes=N_CPUS)
+partial_images = partial(load_images_parrallel,
+                         seq='images',
+                         target_shape=IMAGE_SHAPE,
+                         target_space=TARGET_SPACING)
+partial_seg = partial(load_images_parrallel,
+                      seq='seg',
+                      target_shape=IMAGE_SHAPE,
+                      target_space=TARGET_SPACING)
+# load images
+images = []
+for s in SERIES:
+    image_paths_seq = image_paths[s]
+    image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX_top10]
+    data_list = pool.map(partial_images, image_paths_index)
+    data = np.stack(data_list, axis=0)
+    images.append(data)
+images_list = np.transpose(images, (1, 2, 3, 4, 0))
+# load segmentations
+seg_paths_index = np.asarray(seg_paths)[TEST_INDEX_top10]
+data_list = pool.map(partial_seg, seg_paths_index)
+segmentations = np.stack(data_list, axis=0)
 ########### load module ##################
+print(' >>>>>>> LOAD MODEL <<<<<<<<<')
 dependencies = {
     'dice_coef': dice_coef
 }
 reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies)
 # reconstructed_model.layers[-1].activation = tf.keras.activations.linear
-print('START prediction')
+######### Build Saliency heatmap ##############
+print(' >>>>>>> Build saliency map <<<<<<<<<')
 ig = IntegratedGradients(reconstructed_model)
 saliency_map = []
 for img_idx in range(len(images_list)):
     input_img = np.resize(images_list[img_idx],(1,192,192,24,len(SERIES)))
     saliency_map.append(ig.get_mask(input_img).numpy())
-print("size saliency map is:",np.shape(saliency_map))
-np.save('saliency',saliency_map)
-# Christian Roest, [11-3-2022 15:30]
-# input_img has dimensions (1, 48, 48, 8, 8)
-# reconstructed_model.summary(line_length=120)
-# make predictions on all val_indx
-print('START saliency')
-# predictions_blur = reconstructed_model.predict(images_list, batch_size=1)
+print("size saliency map",np.shape(saliency_map))
+np.save(f'{YAML_DIR}/saliency',saliency_map)
+np.save(f'{YAML_DIR}/images_list',images_list)
+np.save(f'{YAML_DIR}/segmentations',segmentations)
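Note that the updated loading code calls os.environ, multiprocessing.Pool and functools.partial, but no matching imports appear in the hunks above, so the script as shown would raise a NameError. A minimal sketch of the header it seems to need, plus a self-contained illustration of the Pool-plus-partial pattern (load_images_parrallel is replaced by a dummy loader here; everything in this sketch is an assumption, not part of the commit):

# imports the new parallel-loading block appears to rely on
import os                      # os.environ["CUDA_VISIBLE_DEVICES"]
import multiprocessing         # multiprocessing.Pool(processes=N_CPUS)
from functools import partial  # partial(load_images_parrallel, ...)

def load_dummy(path, seq, target_shape, target_space):
    # stand-in for sfransen.load_images.load_images_parrallel
    return f"{seq}: {path} -> shape {target_shape}, spacing {target_space}"

if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=4)
    loader = partial(load_dummy, seq='images',
                     target_shape=(192, 192, 24), target_space=(0.5, 0.5, 3))
    # each worker receives one path as the first positional argument
    print(pool.map(loader, ['img_001.nii.gz', 'img_002.nii.gz']))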


@@ -1,90 +1,57 @@
+import argparse
 import numpy as np
 import matplotlib.pyplot as plt
-import matplotlib.cm as cm
-heatmap = np.load('saliency.npy')
-print(np.shape(heatmap))
+# import matplotlib.cm as cm
+parser = argparse.ArgumentParser(
+    description='Calculate the froc metrics and store in froc_metrics.yml')
+parser.add_argument('-experiment',
+                    help='Title of experiment')
+parser.add_argument('--series', '-s',
+                    metavar='[series_name]', required=True, nargs='+',
+                    help='List of series to include')
+args = parser.parse_args()
+########## constants #################
+SERIES = args.series
+series_ = '_'.join(args.series)
+EXPERIMENT = args.experiment
+SALIENCY_DIR = f'./../train_output/{EXPERIMENT}_{series_}/saliency.npy'
+IMAGES_DIR = f'./../train_output/{EXPERIMENT}_{series_}/images_list.npy'
+SEGMENTATION_DIR = f'./../train_output/{EXPERIMENT}_{series_}/segmentations.npy'
+########## load saliency map ############
+heatmap = np.load(SALIENCY_DIR)
 heatmap = np.squeeze(heatmap)
+######### load images and segmentations ###########
+images_list = np.load(IMAGES_DIR)
+images_list = np.squeeze(images_list)
+segmentations = np.load(SEGMENTATION_DIR)
+######## take average ##########
+# len(heatmap) is smaller than the maximum number of images
+# if len(heatmap) < 100:
+#     heatmap = np.mean(abs(heatmap),axis=0)
+heatmap = abs(heatmap)
+fig, axes = plt.subplots(2, len(SERIES))
+print(np.shape(axes))
 print(np.shape(heatmap))
-### take average over 5 #########
-heatmap = np.mean(abs(heatmap),axis=0)
-print(np.shape(heatmap))
-SERIES = ['t2','b50','b400','b800','b1400','adc']
-fig, axes = plt.subplots(1,6)
+print(np.shape(images_list))
 max_value = np.amax(heatmap)
 min_value = np.amin(heatmap)
-# vmin/vmax of the whole heatmap for scaling in imshow
-# cmap to grey
-im = axes[0].imshow(np.squeeze(heatmap[:,:,12,0]))
-axes[1].imshow(np.squeeze(heatmap[:,:,12,1]), vmin=min_value, vmax=max_value)
-axes[2].imshow(np.squeeze(heatmap[:,:,12,2]), vmin=min_value, vmax=max_value)
-axes[3].imshow(np.squeeze(heatmap[:,:,12,3]), vmin=min_value, vmax=max_value)
-axes[4].imshow(np.squeeze(heatmap[:,:,12,4]), vmin=min_value, vmax=max_value)
-axes[5].imshow(np.squeeze(heatmap[:,:,12,5]), vmin=min_value, vmax=max_value)
-axes[0].set_title("t2")
-axes[1].set_title("b50")
-axes[2].set_title("b400")
-axes[3].set_title("b800")
-axes[4].set_title("b1400")
-axes[5].set_title("adc")
+for indx in range(len(SERIES)):
+    print(indx)
+    axes[0,indx].imshow(images_list[:,:,12,indx], cmap='gray')
+    im = axes[1,indx].imshow(np.squeeze(heatmap[:,:,12,indx]), vmin=min_value, vmax=max_value)
+    axes[0,indx].set_title(SERIES[indx])
+    axes[0,indx].set_axis_off()
+    axes[1,indx].set_axis_off()
 cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.5, orientation='horizontal')
-cbar.set_ticks([-0.1,0,0.1])
-cbar.set_ticklabels(['less importance', '0', 'important'])
-fig.suptitle('Average saliency maps over the 5 highest predictions', fontsize=16)
-plt.show()
+cbar.set_ticks([min_value, max_value])
+cbar.set_ticklabels(['less important', 'important'])
+fig.suptitle('Saliency map', fontsize=16)
+plt.savefig(f'./../train_output/{EXPERIMENT}_{series_}/saliency_map.png', dpi=300)
-quit()
-#take one image out
-heatmap = np.squeeze(heatmap[0])
-import numpy as np
-import matplotlib.pyplot as plt
-# Fixing random state for reproducibility
-np.random.seed(19680801)
-class IndexTracker:
-    def __init__(self, ax, X):
-        self.ax = ax
-        ax.set_title('use scroll wheel to navigate images')
-        self.X = X
-        rows, cols, self.slices = X.shape
-        self.ind = self.slices//2
-        self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet')
-        self.update()
-    def on_scroll(self, event):
-        print("%s %s" % (event.button, event.step))
-        if event.button == 'up':
-            self.ind = (self.ind + 1) % self.slices
-        else:
-            self.ind = (self.ind - 1) % self.slices
-        self.update()
-    def update(self):
-        self.im.set_data(self.X[:, :, self.ind])
-        self.ax.set_ylabel('slice %s' % self.ind)
-        self.im.axes.figure.canvas.draw()
-plt.figure(0)
-fig, ax = plt.subplots(1, 1)
-tracker = IndexTracker(ax, heatmap[:,:,:,5])
-fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
-plt.show()
-plt.figure(1)
-fig, ax = plt.subplots(1, 1)
-tracker = IndexTracker(ax, heatmap[:,:,:,3])
-fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
-plt.show()
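One small observation on the new plotting script: it loads segmentations from SEGMENTATION_DIR but never draws them. If the lesion masks should appear in the figure, a sketch along these lines could overlay them on the top-row images; the array shapes and the slice index 12 are assumptions based on the code above, not something this commit does:

import numpy as np
import matplotlib.pyplot as plt

def show_with_contour(ax, image_slice, seg_slice):
    # grayscale background with the segmentation drawn as a red contour
    ax.imshow(image_slice, cmap='gray')
    ax.contour(seg_slice, levels=[0.5], colors='r', linewidths=0.5)
    ax.set_axis_off()

# hypothetical use inside the existing loop, assuming segmentations has a
# leading batch axis and the same in-plane grid as images_list:
# show_with_contour(axes[0, indx], images_list[:, :, 12, indx], segmentations[0, :, :, 12])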

scripts/scroll_trough.py Executable file

@@ -0,0 +1,48 @@
+quit()
+#take one image out
+heatmap = np.squeeze(heatmap[0])
+import numpy as np
+import matplotlib.pyplot as plt
+# Fixing random state for reproducibility
+np.random.seed(19680801)
+class IndexTracker:
+    def __init__(self, ax, X):
+        self.ax = ax
+        ax.set_title('use scroll wheel to navigate images')
+        self.X = X
+        rows, cols, self.slices = X.shape
+        self.ind = self.slices//2
+        self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet')
+        self.update()
+    def on_scroll(self, event):
+        print("%s %s" % (event.button, event.step))
+        if event.button == 'up':
+            self.ind = (self.ind + 1) % self.slices
+        else:
+            self.ind = (self.ind - 1) % self.slices
+        self.update()
+    def update(self):
+        self.im.set_data(self.X[:, :, self.ind])
+        self.ax.set_ylabel('slice %s' % self.ind)
+        self.im.axes.figure.canvas.draw()
+plt.figure(0)
+fig, ax = plt.subplots(1, 1)
+tracker = IndexTracker(ax, heatmap[:,:,:,5])
+fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
+plt.show()
+plt.figure(1)
+fig, ax = plt.subplots(1, 1)
+tracker = IndexTracker(ax, heatmap[:,:,:,3])
+fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
+plt.show()
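As committed, scripts/scroll_trough.py starts with quit() and refers to a heatmap variable defined in the plotting script, so it cannot run on its own. A usage sketch for the IndexTracker viewer it contains, assuming the file is first reduced to just the class definition; the import and the random stand-in volume below are assumptions:

import numpy as np
import matplotlib.pyplot as plt
from scroll_trough import IndexTracker  # assumes the module only defines the class

volume = np.random.rand(192, 192, 24)   # stand-in for one saliency volume
fig, ax = plt.subplots(1, 1)
tracker = IndexTracker(ax, volume)
fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
plt.show()  # scroll the mouse wheel to move through slices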

scripts/scroll_trough.txt Executable file