diff --git a/scripts/4.frocs.py b/scripts/4.frocs.py index c1d05bb..13f346a 100755 --- a/scripts/4.frocs.py +++ b/scripts/4.frocs.py @@ -42,7 +42,7 @@ INPUT_SHAPE = (192, 192, 24, len(SERIES)) IMAGE_SHAPE = INPUT_SHAPE[:3] DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') -TEST_INDEX = DATA_SPLIT_INDEX['val_set0'] +TEST_INDEX = DATA_SPLIT_INDEX['test_set0'] N_CPUS = 12 diff --git a/scripts/6.saliency_map.py b/scripts/6.saliency_map.py index dc41f1a..c032d20 100755 --- a/scripts/6.saliency_map.py +++ b/scripts/6.saliency_map.py @@ -1,53 +1,60 @@ -import sys +import argparse from os import path import SimpleITK as sitk import tensorflow as tf -from tensorflow import keras from tensorflow.keras.models import load_model -from focal_loss import BinaryFocalLoss -import json -import matplotlib.pyplot as plt import numpy as np -from sfransen.Saliency.base import * -from sfransen.Saliency.integrated_gradients import * from sfransen.utils_quintin import * from sfransen.DWI_exp import preprocess from sfransen.DWI_exp.helpers import * from sfransen.DWI_exp.callbacks import dice_coef from sfransen.FROC.blob_preprocess import * from sfransen.FROC.cal_froc_from_np import * +from sfransen.load_images import load_images_parrallel +from sfransen.Saliency.base import * +from sfransen.Saliency.integrated_gradients import * +parser = argparse.ArgumentParser( + description='Calculate the froc metrics and store in froc_metrics.yml') +parser.add_argument('-experiment', + help='Title of experiment') +parser.add_argument('--series', '-s', + metavar='[series_name]', required=True, nargs='+', + help='List of series to include') +args = parser.parse_args() +######## CUDA ################ +os.environ["CUDA_VISIBLE_DEVICES"] = "2" +######## constants ############# +SERIES = args.series +series_ = '_'.join(args.series) +EXPERIMENT = args.experiment + +MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}/models/{EXPERIMENT}_{series_}.h5' +YAML_DIR = f'./../train_output/{EXPERIMENT}_{series_}' -# train_10h_t2_b50_b400_b800_b1400_adc -SERIES = ['t2','b50','b400','b800','b1400','adc'] -MODEL_PATH = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/models/train_10h_t2_b50_b400_b800_b1400_adc.h5' -YAML_DIR = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc' -################ constants ############ DATA_DIR = "./../data/Nijmegen paths/" TARGET_SPACING = (0.5, 0.5, 3) INPUT_SHAPE = (192, 192, 24, len(SERIES)) IMAGE_SHAPE = INPUT_SHAPE[:3] -# import val_indx -# DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') -# TEST_INDEX = DATA_SPLIT_INDEX['val_set0'] +froc_metrics = read_yaml_to_dict(f'{YAML_DIR}/froc_metrics.yml') +top_10_idx = np.argsort(froc_metrics['roc_pred'])[-1 :] -experiment_path = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/froc_metrics.yml' -experiment_metrics = read_yaml_to_dict(experiment_path) DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') TEST_INDEX = DATA_SPLIT_INDEX['test_set0'] -top_10_idx = np.argsort(experiment_metrics['roc_pred'])[-10:] -TEST_INDEX = [TEST_INDEX[i] for i in top_10_idx] +TEST_INDEX_top10 = [TEST_INDEX[i] for i in top_10_idx] -########## load images ############## -images, image_paths = {s: [] for s in SERIES}, {} -segmentations = [] +N_CPUS = 12 + +########## load images in parrallel ############## print_(f"> Loading images into RAM...") +# read paths from txt +image_paths = {} for s in SERIES: with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f: image_paths[s] = [l.strip() for l in f.readlines()] @@ -55,45 +62,52 @@ with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f: seg_paths = [l.strip() for l in f.readlines()] num_images = len(seg_paths) -# Read and preprocess each of the paths for each SERIES, and the segmentations. -for img_idx in TEST_INDEX[:5]: #for less images - img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32) - for s in SERIES} - seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32) - img_n, seg_n = preprocess(img_s, seg_s, - shape=IMAGE_SHAPE, spacing=TARGET_SPACING) - for seq in img_n: - images[seq].append(img_n[seq]) - segmentations.append(seg_n) +# create pool of workers +pool = multiprocessing.Pool(processes=N_CPUS) +partial_images = partial(load_images_parrallel, + seq = 'images', + target_shape=IMAGE_SHAPE, + target_space = TARGET_SPACING) +partial_seg = partial(load_images_parrallel, + seq = 'seg', + target_shape=IMAGE_SHAPE, + target_space = TARGET_SPACING) -images_list = [images[s] for s in images.keys()] -images_list = np.transpose(images_list, (1, 2, 3, 4, 0)) +#load images +images = [] +for s in SERIES: + image_paths_seq = image_paths[s] + image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX_top10] + data_list = pool.map(partial_images,image_paths_index) + data = np.stack(data_list, axis=0) + images.append(data) + +images_list = np.transpose(images, (1, 2, 3, 4, 0)) + +#load segmentations +seg_paths_index = np.asarray(seg_paths)[TEST_INDEX_top10] +data_list = pool.map(partial_seg,seg_paths_index) +segmentations = np.stack(data_list, axis=0) ########### load module ################## +print(' >>>>>>> LOAD MODEL <<<<<<<<<') + dependencies = { 'dice_coef': dice_coef } reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies) - # reconstructed_model.layers[-1].activation = tf.keras.activations.linear -print('START prediction') +######### Build Saliency heatmap ############## +print(' >>>>>>> Build saliency map <<<<<<<<<') ig = IntegratedGradients(reconstructed_model) saliency_map = [] for img_idx in range(len(images_list)): input_img = np.resize(images_list[img_idx],(1,192,192,24,len(SERIES))) saliency_map.append(ig.get_mask(input_img).numpy()) - print("size saliency map is:",np.shape(saliency_map)) - -np.save('saliency',saliency_map) - -# Christian Roest, [11-3-2022 15:30] -# input_img heeft dimensies (1, 48, 48, 8, 8) - -# reconstructed_model.summary(line_length=120) - -# make predictions on all val_indx -print('START saliency') -# predictions_blur = reconstructed_model.predict(images_list, batch_size=1) +print("size saliency map",np.shape(saliency_map)) +np.save(f'{YAML_DIR}/saliency',saliency_map) +np.save(f'{YAML_DIR}/images_list',images_list) +np.save(f'{YAML_DIR}/segmentations',segmentations) \ No newline at end of file diff --git a/scripts/7.Visualize_saliency.py b/scripts/7.Visualize_saliency.py index 4e7b5bd..37ae310 100755 --- a/scripts/7.Visualize_saliency.py +++ b/scripts/7.Visualize_saliency.py @@ -1,90 +1,57 @@ +import argparse import numpy as np import matplotlib.pyplot as plt -import matplotlib.cm as cm +# import matplotlib.cm as cm -heatmap = np.load('saliency.npy') -print(np.shape(heatmap)) +parser = argparse.ArgumentParser( + description='Calculate the froc metrics and store in froc_metrics.yml') +parser.add_argument('-experiment', + help='Title of experiment') +parser.add_argument('--series', '-s', + metavar='[series_name]', required=True, nargs='+', + help='List of series to include') +args = parser.parse_args() + +########## constants ################# +SERIES = args.series +series_ = '_'.join(args.series) +EXPERIMENT = args.experiment +SALIENCY_DIR = f'./../train_output/{EXPERIMENT}_{series_}/saliency.npy' +IMAGES_DIR = f'./../train_output/{EXPERIMENT}_{series_}/images_list.npy' +SEGMENTATION_DIR = f'./../train_output/{EXPERIMENT}_{series_}/segmentations.npy' + +########## load saliency map ############ +heatmap = np.load(SALIENCY_DIR) heatmap = np.squeeze(heatmap) + +######### load images and segmentations ########### +images_list = np.load(IMAGES_DIR) +images_list = np.squeeze(images_list) +segmentations = np.load(SEGMENTATION_DIR) +######## take average ########## +# len(heatmap) is smaller then maximum number of images +# if len(heatmap) < 100: + # heatmap = np.mean(abs(heatmap),axis=0) +heatmap = abs(heatmap) + +fig, axes = plt.subplots(2,len(SERIES)) +print(np.shape(axes)) print(np.shape(heatmap)) - - -### take average over 5 ######### -heatmap = np.mean(abs(heatmap),axis=0) -print(np.shape(heatmap)) - -SERIES = ['t2','b50','b400','b800','b1400','adc'] -fig, axes = plt.subplots(1,6) +print(np.shape(images_list)) max_value = np.amax(heatmap) -pri min_value = np.amin(heatmap) -# vmin vmax van hele heatmap voor scaling in imshow -# cmap naar grey -im = axes[0].imshow(np.squeeze(heatmap[:,:,12,0])) -axes[1].imshow(np.squeeze(heatmap[:,:,12,1]), vmin=min_value, vmax=max_value) -axes[2].imshow(np.squeeze(heatmap[:,:,12,2]), vmin=min_value, vmax=max_value) -axes[3].imshow(np.squeeze(heatmap[:,:,12,3]), vmin=min_value, vmax=max_value) -axes[4].imshow(np.squeeze(heatmap[:,:,12,4]), vmin=min_value, vmax=max_value) -axes[5].imshow(np.squeeze(heatmap[:,:,12,5]), vmin=min_value, vmax=max_value) - -axes[0].set_title("t2") -axes[1].set_title("b50") -axes[2].set_title("b400") -axes[3].set_title("b800") -axes[4].set_title("b1400") -axes[5].set_title("adc") +for indx in range(len(SERIES)): + print(indx) + axes[0,indx].imshow(images_list[:,:,12,indx],cmap='gray') + im = axes[1,indx].imshow(np.squeeze(heatmap[:,:,12,indx]),vmin=min_value, vmax=max_value) + axes[0,indx].set_title(SERIES[indx]) + axes[0,indx].set_axis_off() + axes[1,indx].set_axis_off() cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.5, orientation='horizontal') -cbar.set_ticks([-0.1,0,0.1]) -cbar.set_ticklabels(['less importance', '0', 'important']) -fig.suptitle('Average saliency maps over the 5 highest predictions', fontsize=16) -plt.show() +cbar.set_ticks([min_value,max_value]) +cbar.set_ticklabels(['less important', 'important']) +fig.suptitle('Saliency map', fontsize=16) +plt.savefig(f'./../train_output/{EXPERIMENT}_{series_}/saliency_map.png', dpi=300) -quit() - -#take one image out -heatmap = np.squeeze(heatmap[0]) - -import numpy as np -import matplotlib.pyplot as plt - -# Fixing random state for reproducibility -np.random.seed(19680801) - - -class IndexTracker: - def __init__(self, ax, X): - self.ax = ax - ax.set_title('use scroll wheel to navigate images') - - self.X = X - rows, cols, self.slices = X.shape - self.ind = self.slices//2 - - self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet') - self.update() - - def on_scroll(self, event): - print("%s %s" % (event.button, event.step)) - if event.button == 'up': - self.ind = (self.ind + 1) % self.slices - else: - self.ind = (self.ind - 1) % self.slices - self.update() - - def update(self): - self.im.set_data(self.X[:, :, self.ind]) - self.ax.set_ylabel('slice %s' % self.ind) - self.im.axes.figure.canvas.draw() - -plt.figure(0) -fig, ax = plt.subplots(1, 1) -tracker = IndexTracker(ax, heatmap[:,:,:,5]) -fig.canvas.mpl_connect('scroll_event', tracker.on_scroll) -plt.show() - -plt.figure(1) -fig, ax = plt.subplots(1, 1) -tracker = IndexTracker(ax, heatmap[:,:,:,3]) -fig.canvas.mpl_connect('scroll_event', tracker.on_scroll) -plt.show() diff --git a/scripts/scroll_trough.py b/scripts/scroll_trough.py new file mode 100755 index 0000000..0413f1a --- /dev/null +++ b/scripts/scroll_trough.py @@ -0,0 +1,48 @@ + +quit() +#take one image out +heatmap = np.squeeze(heatmap[0]) + +import numpy as np +import matplotlib.pyplot as plt + +# Fixing random state for reproducibility +np.random.seed(19680801) + + +class IndexTracker: + def __init__(self, ax, X): + self.ax = ax + ax.set_title('use scroll wheel to navigate images') + + self.X = X + rows, cols, self.slices = X.shape + self.ind = self.slices//2 + + self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet') + self.update() + + def on_scroll(self, event): + print("%s %s" % (event.button, event.step)) + if event.button == 'up': + self.ind = (self.ind + 1) % self.slices + else: + self.ind = (self.ind - 1) % self.slices + self.update() + + def update(self): + self.im.set_data(self.X[:, :, self.ind]) + self.ax.set_ylabel('slice %s' % self.ind) + self.im.axes.figure.canvas.draw() + +plt.figure(0) +fig, ax = plt.subplots(1, 1) +tracker = IndexTracker(ax, heatmap[:,:,:,5]) +fig.canvas.mpl_connect('scroll_event', tracker.on_scroll) +plt.show() + +plt.figure(1) +fig, ax = plt.subplots(1, 1) +tracker = IndexTracker(ax, heatmap[:,:,:,3]) +fig.canvas.mpl_connect('scroll_event', tracker.on_scroll) +plt.show() diff --git a/scripts/scroll_trough.txt b/scripts/scroll_trough.txt new file mode 100755 index 0000000..e69de29