add scripts

This commit is contained in:
Stefan 2022-03-21 10:14:00 +01:00
parent 8cd3d865da
commit be5b392456
8 changed files with 852 additions and 1 deletions

1
.gitignore vendored
View File

@ -143,7 +143,6 @@ cython_debug/
/old_code/
/data/
/job_scripts/
/scripts/
/temp/
/slurms/
*.out

203
scripts/1.U-net_chris.py Executable file
View File

@ -0,0 +1,203 @@
import multiprocessing
from os import path
import argparse
import time
from datetime import datetime
import sys
# sys.path.append('./../code')
# from utils_quintin import *
# from sfransen.utils_quintin import *
# sys.path.append('./../code/DWI_exp')
# from callbacks import IntermediateImages, dice_coef
# from callbacks import RocCallback
from sfransen.DWI_exp import IntermediateImages, dice_coef
from sfransen.DWI_exp.preprocessing_function import preprocess
import yaml
import numpy as np
from tqdm import tqdm
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import KFold
from sfransen.DWI_exp.helpers import *
from sfransen.DWI_exp.batchgenerator import BatchGenerator
from sfransen.DWI_exp.unet import build_dual_attention_unet
from focal_loss import BinaryFocalLoss
# Command-line interface: the MRI series to train on and the experiment title
# that names the output folder and the files written into it.
parser = argparse.ArgumentParser(
    # The two string parts are concatenated, so the separating space must be
    # part of the literal (it used to read "tasks.using").
    description='Train a U-Net model for segmentation/detection tasks ' +
                'using cross-validation.')
parser.add_argument('--series', '-s',
                    metavar='[series_name]', required=True, nargs='+',
                    help='List of series to include, must correspond with ' +
                         "path files in ./data/")
# The experiment title is used to build the output directory and the
# checkpoint/log file names. It is required so that a missing value is
# rejected immediately instead of failing on `None + ".h5"` after all images
# have been loaded. The original single-dash spelling keeps working.
parser.add_argument('-experiment', '--experiment', required=True,
                    help='add experiment title to store the files correctly: test_b50_b400_b800'
                    )
args = parser.parse_args()
# Determine the number of input series
num_series = len(args.series)

# The experiment title identifies this job: it names the output folder and the
# checkpoint/log files written into it.
JOB_NAME = args.experiment
DATA_DIR = "./../data/Nijmegen paths/"
PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{JOB_NAME}"

# Spacing and spatial shape handed to preprocess().
# NOTE(review): the previous comment here read "2 x 2mm2 in-plane resolution,
# 3.6mm slice thickness", which does not match the values below - confirm.
TARGET_SPACING = (0.5, 0.5, 3)
INPUT_SHAPE = (192, 192, 24, num_series)
IMAGE_SHAPE = INPUT_SHAPE[:3]
OUTPUT_SHAPE = (192, 192, 24, 1)  # One output channel (segmentation)

# Hyperparameters
FOCAL_LOSS_GAMMA = 2
INITIAL_LEARNING_RATE = 1e-4
MAX_EPOCHS = 600
EARLY_STOPPING = 50  # patience (epochs) passed to the EarlyStopping callback
BATCH_SIZE = 12
MODEL_SELECTION_METRIC = 'val_loss'
MODEL_SELECTION_DIRECTION = "min"  # Change to 'max' if higher value is better
EARLY_STOPPING_METRIC = 'val_loss'
EARLY_STOPPING_DIRECTION = "min"  # Change to 'max' if higher value is better

# Training configuration
# TODO: add metric ROC_AUC
TRAINING_METRICS = ["binary_crossentropy", "binary_accuracy", dice_coef]
loss = BinaryFocalLoss(gamma=FOCAL_LOSS_GAMMA)
optimizer = Adam(learning_rate=INITIAL_LEARNING_RATE)

# Create folder structure in the output directory.
# If the experiment folder already exists, switch to the first free
# "<name>_(k)" folder. PROJECT_DIR itself is updated so that every later
# output (params, checkpoints, logs, images) lands in the new folder; before,
# only the "_(2)" folder was created while all files still overwrote the
# existing experiment.
if path.exists(PROJECT_DIR):
    suffix = 2
    while path.exists(f"{PROJECT_DIR}_({suffix})"):
        suffix += 1
    PROJECT_DIR = f"{PROJECT_DIR}_({suffix})"
prepare_project_dir(PROJECT_DIR)

# Save params to yaml
# NOTE(review): "optimizer" and "loss" are stored as live objects; whether
# dump_dict_to_yaml serialises them readably is not visible here - confirm.
params = {
    "focal_loss_gamma": FOCAL_LOSS_GAMMA,
    "initial_learning_rate": INITIAL_LEARNING_RATE,
    "max_epochs": MAX_EPOCHS,
    "MODEL_SELECTION_METRIC": MODEL_SELECTION_METRIC,
    "EARLY_STOPPING_METRIC": EARLY_STOPPING_METRIC,
    "train_output_dir": PROJECT_DIR,
    "batch_size": BATCH_SIZE,
    "optimizer": optimizer,
    "loss": loss,
    # Store the date string itself; wrapping it in print() stored None.
    "datetime": datetime.now().strftime("%Y-%m-%d")}
dump_dict_to_yaml(params, f"{PROJECT_DIR}", filename=f"params")
# Instantiate the network and print an overview of its layers.
detection_model = build_dual_attention_unet(INPUT_SHAPE)
detection_model.summary(line_length=120)

# Containers for everything that is kept in RAM: one list of preprocessed
# volumes per series, the path lists, and the preprocessed segmentations.
images = {}
for series_name in args.series:
    images[series_name] = []
image_paths = {}
segmentations = []
print_(f"> Loading images into RAM...")

# Every series has a text file "[series_name].txt" in DATA_DIR with one image
# path per line; "seg.txt" holds the segmentation paths.
for series_name in args.series:
    with open(path.join(DATA_DIR, f"{series_name}.txt"), 'r') as path_file:
        image_paths[series_name] = [line.strip() for line in path_file]
with open(path.join(DATA_DIR, f"seg.txt"), 'r') as path_file:
    seg_paths = [line.strip() for line in path_file]
num_images = len(seg_paths)

# Read and preprocess the images of every series plus the segmentation,
# one case at a time.
for img_idx in tqdm(range(num_images)):
    img_s = {}
    for series_name in args.series:
        img_s[series_name] = sitk.ReadImage(
            image_paths[series_name][img_idx], sitk.sitkFloat32)
    seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
    img_n, seg_n = preprocess(
        img_s, seg_s, shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
    for seq in img_n:
        images[seq].append(img_n[seq])
    segmentations.append(seg_n)
# Split train and validation.
# The indexes are read from a pre-computed split file (set 0); an earlier
# version derived a 1:9 split with sklearn's KFold instead.
split_sets = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
train_idxs, valid_idxs = split_sets['train_set0'], split_sets['val_set0']

detection_model.compile(optimizer=optimizer, loss=loss, metrics=TRAINING_METRICS)

# Arguments shared by the training and the validation generator.
generator_kwargs = dict(sequences=args.series, shape=IMAGE_SHAPE)

# Training batches: shuffled, augmented, BATCH_SIZE cases at a time.
train_generator = BatchGenerator(
    images, segmentations,
    indexes=train_idxs,
    batch_size=BATCH_SIZE,
    shuffle=True,
    augmentation_function=augment,
    **generator_kwargs)

# Validation data: no shuffling, no augmentation, batch_size=None.
# NOTE(review): presumably batch_size=None yields the whole validation set as
# one batch - confirm against get_generator.
valid_generator = get_generator(
    images, segmentations,
    indexes=valid_idxs,
    batch_size=None,
    shuffle=False,
    augmentation=None,
    **generator_kwargs)
valid_data = next(valid_generator)

print_(f"The shape of valid_data input = {np.shape(valid_data[0])}")
print_(f"The shape of valid_data label = {np.shape(valid_data[1])}")
# Stop training once the monitored metric has not improved for
# EARLY_STOPPING epochs.
early_stopping = EarlyStopping(
    monitor=EARLY_STOPPING_METRIC,
    mode=EARLY_STOPPING_DIRECTION,
    patience=EARLY_STOPPING,
    verbose=1)

# Keep only the best model (according to MODEL_SELECTION_METRIC) on disk.
best_model_checkpoint = ModelCheckpoint(
    filepath=path.join(PROJECT_DIR, "models", JOB_NAME + ".h5"),
    monitor=MODEL_SELECTION_METRIC,
    mode=MODEL_SELECTION_DIRECTION,
    verbose=1,
    save_best_only=True)

# Per-epoch training statistics as CSV.
csv_logger = CSVLogger(
    filename=path.join(PROJECT_DIR, "logs", f"{JOB_NAME}.csv"))

# Project callback that receives the validation data and an output prefix.
intermediate_images = IntermediateImages(
    validation_set=valid_data,
    sequences=args.series,
    prefix=path.join(PROJECT_DIR, "output", JOB_NAME),
    num_images=25)

# Currently disabled: a second ModelCheckpoint on 'val_dice_coef' (mode 'max',
# folder "models_dice") and a RocCallback on the validation set.
callbacks = [
    early_stopping,
    best_model_checkpoint,
    csv_logger,
    intermediate_images,
]

detection_model.fit(
    train_generator,
    validation_data=valid_data,
    steps_per_epoch=len(train_idxs) // BATCH_SIZE,
    epochs=MAX_EPOCHS,
    callbacks=callbacks,
    verbose=1,
)

View File

@ -0,0 +1,107 @@
import argparse
import random
import sys
sys.path.append('./../code')
from utils_quintin import list_from_file, dump_dict_to_yaml, print_p
################################ README ######################################
# NEW - This script will create indexes for the training, validation and test
# sets. Create a yaml file with 10 sets of indexes for train, val and test. So
# that they can be used for 10 fold cross validation when needed. The filename
# should contain an integer, which is the seed used to generate the indexes. The
# filename also indicates the type of split used.
# Output structure (keys as written by this script, next to "init_seed" and
# "split"):
# train_set0: [indexes]
# val_set0: [indexes]
# test_set0: [indexes]
# train_set1: [indexes]
# val_set1: [indexes]
# test_set1: [indexes]
# etc...
################################ PARSER ########################################
def parse_input_args():
    """Parse the command-line arguments of the index-splitting script.

    Besides the raw options, the returned namespace carries the three
    percentages of ``--split`` as integers in ``p_train``, ``p_val`` and
    ``p_test``.

    Returns:
        argparse.Namespace: the parsed options plus ``p_train``, ``p_val``
        and ``p_test``.

    Exits through ``parser.error`` (usage message, exit status 2) when
    ``--split`` is not three non-negative integers separated by '/' or when
    they do not sum to 100.
    """
    parser = argparse.ArgumentParser(description='Parse arguments for splitting training, validation and the test set.')

    parser.add_argument('-n',
                        '--num_folds',
                        type=int,
                        default=1,
                        help='The number of folds in total. This amount of index sets will be created.')

    parser.add_argument('-p',
                        '--path_to_paths_file',
                        type=str,
                        default="./../data/Nijmegen paths/seg.txt",
                        help='Path to the .txt file containting paths to the nifti files.')

    parser.add_argument('-s',
                        '--split',
                        type=str,
                        default="80/10/10",
                        help='Train/validation/test split in percentages.')

    args = parser.parse_args()

    # Split the given split string. A wrong number of parts and non-integer
    # parts both surface as ValueError (from unpacking resp. int()).
    try:
        p_train, p_val, p_test = (int(part) for part in args.split.split('/'))
    except ValueError:
        parser.error(f"--split must look like '80/10/10', got '{args.split}'.")

    # Validate explicitly instead of with assert, which is stripped under -O.
    if min(p_train, p_val, p_test) < 0:
        parser.error("The train, val and test percentages must not be negative.")
    if p_train + p_val + p_test != 100:
        parser.error("The train, val and test split must sum to 100%.")

    args.p_train, args.p_val, args.p_test = p_train, p_val, p_test
    return args
################################################################################
SEED = 3478


def split_cutoffs(num_obs, p_train, p_val):
    """Return the end positions of the train, val and test slices.

    The cutoffs are computed cumulatively with integer arithmetic. This avoids
    float truncation (``int(57/100 * 100)`` is 56) and makes the test slice
    end at ``num_obs``, so no observation is left out of all three sets.

    Args:
        num_obs: total number of observations.
        p_train: training percentage (integer).
        p_val: validation percentage (integer); the remainder is the test set.

    Returns:
        Tuple ``(train_cutoff, val_cutoff, test_cutoff)``.
    """
    train_cutoff = p_train * num_obs // 100
    val_cutoff = (p_train + p_val) * num_obs // 100
    return train_cutoff, val_cutoff, num_obs


def split_indexes(num_obs, p_train, p_val, seed):
    """Shuffle ``range(num_obs)`` with ``seed`` and cut it into three lists.

    Returns:
        Tuple ``(train_idxs, val_idxs, test_idxs)`` of disjoint index lists
        that together contain every index exactly once.
    """
    indexes = list(range(num_obs))
    # A dedicated generator gives the same permutation as seeding the module
    # level generator, without touching global random state.
    random.Random(seed).shuffle(indexes)
    train_cutoff, val_cutoff, test_cutoff = split_cutoffs(num_obs, p_train, p_val)
    return (indexes[:train_cutoff],
            indexes[train_cutoff:val_cutoff],
            indexes[val_cutoff:test_cutoff])


if __name__ == '__main__':

    print_p('\n\nMaking Train - Validation - Test indexes based.')

    # Parse some arguments
    args = parse_input_args()
    print_p(args)

    # Read the amount of observations/subjects in the data.
    t2_paths = list_from_file(args.path_to_paths_file)
    num_obs = len(t2_paths)
    print_p(f"Number of observations in {args.path_to_paths_file}: {len(t2_paths)}")

    # Report the cutoff points for training, validation and test set.
    train_cutoff, val_cutoff, test_cutoff = split_cutoffs(num_obs, args.p_train, args.p_val)
    print(f"\ncutoffs: {train_cutoff}, {val_cutoff}, {test_cutoff}")

    # Create dict that will hold all the data
    data_dict = {}
    data_dict["init_seed"] = SEED
    data_dict["split"] = args.split

    # One set of indexes per fold; every fold uses its own seed.
    for set_idx in range(args.num_folds):
        train_idxs, val_idxs, test_idxs = split_indexes(
            num_obs, args.p_train, args.p_val, SEED + set_idx)
        data_dict[f"train_set{set_idx}"] = train_idxs
        data_dict[f"val_set{set_idx}"] = val_idxs
        data_dict[f"test_set{set_idx}"] = test_idxs

    # Print the size of every index list.
    for key, value in data_dict.items():
        if isinstance(value, list):
            print(f"{key}: {len(value)}")

    # NOTE(review): the training/evaluation scripts read
    # './../data/Nijmegen paths/train_val_test_idxs.yml' - confirm how the
    # file written here ends up in that folder.
    dump_dict_to_yaml(data_dict, "./../data", filename=f"train_val_test_idxs", verbose=False)

217
scripts/4.frocs.py Executable file
View File

@ -0,0 +1,217 @@
import sys
from os import path
import SimpleITK as sitk
import tensorflow as tf
from tensorflow.keras.models import load_model
from focal_loss import BinaryFocalLoss
import json
import matplotlib.pyplot as plt
import numpy as np
import multiprocessing
from functools import partial
sys.path.append('./../code')
from utils_quintin import *
sys.path.append('./../code/DWI_exp')
from helpers import *
from preprocessing_function import preprocess
from callbacks import dice_coef
sys.path.append('./../code/FROC')
from blob_preprocess import *
from cal_froc_from_np import *
# argparse is imported explicitly so the parser does not depend on what the
# star imports above happen to re-export.
import argparse

parser = argparse.ArgumentParser(
    description='Evaluate a trained U-Net model on the validation set and ' +
                'store the FROC metrics.')
parser.add_argument('--series', '-s',
                    metavar='[series_name]', required=True, nargs='+',
                    help='List of series to include, must correspond with ' +
                         "path files in ./data/")
args = parser.parse_args()

######## parsed inputs #############
SERIES = args.series
series_ = '_'.join(args.series)

# Location of the trained model and of the folder that receives the metrics.
# The experiment prefix is hard-coded; earlier runs used 'train_10h_'.
MODEL_PATH = f'./../train_output/train_n0.001_{series_}/models/train_n0.001_{series_}.h5'
print(MODEL_PATH)
YAML_DIR = f'./../train_output/train_n0.001_{series_}'

################ constants ############
DATA_DIR = "./../data/Nijmegen paths/"
TARGET_SPACING = (0.5, 0.5, 3)
INPUT_SHAPE = (192, 192, 24, len(SERIES))
IMAGE_SHAPE = INPUT_SHAPE[:3]

# Indexes of the cases to evaluate. Despite the name TEST_INDEX, this is the
# validation set of split 0.
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
########## load images ##############
# One (still empty) list per series, plus the path lists read from DATA_DIR.
images = {}
for series_name in SERIES:
    images[series_name] = []
image_paths = {}
segmentations = []
print_(f"> Loading images into RAM...")

# "[series_name].txt" lists one image path per line for that series.
for series_name in SERIES:
    with open(path.join(DATA_DIR, f"{series_name}.txt"), 'r') as path_file:
        image_paths[series_name] = [line.strip() for line in path_file]

# "seg.txt" lists the segmentation paths.
with open(path.join(DATA_DIR, f"seg.txt"), 'r') as path_file:
    seg_paths = [line.strip() for line in path_file]
num_images = len(seg_paths)

# The images and segmentations themselves are read and preprocessed below.
from typing import List
def load_images(
        image_paths: str,
        seq: str,
        target_shape: List[int],
        target_space: List[float]):
    """Read one image file and return it as a preprocessed numpy array.

    The image is read as float32, resampled, center-cropped and then either
    normalized (images) or thresholded to a binary mask (segmentations).

    Args:
        image_paths: path of a single image file (the plural name is kept for
            keyword-argument compatibility).
        seq: ``'seg'`` selects the segmentation branch (binary threshold with
            lower threshold 1.0); any other value selects normalization.
        target_shape: passed to ``resample`` as ``min_shape`` and to
            ``center_crop`` as ``shape``.
        target_space: passed to ``resample`` as ``new_spacing``. This used to
            be declared as ``target_space = List[float]``, which made the
            typing object the default value instead of annotating the
            parameter.

    Returns:
        The transposed array obtained from ``sitk.GetArrayFromImage``.
    """
    img_s = sitk.ReadImage(image_paths, sitk.sitkFloat32)

    # Resample to the target spacing.
    # NOTE(review): nearest-neighbour interpolation is used for images as well
    # as for segmentations - confirm this is intended.
    mri_tra_s = resample(img_s,
                         min_shape=target_shape,
                         method=sitk.sitkNearestNeighbor,
                         new_spacing=target_space)

    # Center crop
    mri_tra_s = center_crop(mri_tra_s, shape=target_shape)

    # Normalize images, binarize segmentations. The local is not called
    # `filter` so that the builtin is not shadowed.
    if seq != 'seg':
        image_filter = sitk.NormalizeImageFilter()
    else:
        image_filter = sitk.BinaryThresholdImageFilter()
        image_filter.SetLowerThreshold(1.0)
    mri_tra_s = image_filter.Execute(mri_tra_s)

    return sitk.GetArrayFromImage(mri_tra_s).T
N_CPUS = 12

# Load and preprocess the selected cases in parallel. The pool is used as a
# context manager so its worker processes are released once loading is done;
# pool.map() blocks until all results are available, so nothing is lost when
# the pool is torn down on exit.
images_2 = []
seg_paths_index = np.asarray(seg_paths)[TEST_INDEX]
with multiprocessing.Pool(processes=N_CPUS) as pool:
    # Images: one stacked array per series, in SERIES order.
    partial_f = partial(load_images,
                        seq='images',
                        target_shape=IMAGE_SHAPE,
                        target_space=TARGET_SPACING)
    for s in SERIES:
        image_paths_seq = image_paths[s]
        image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX]
        data_list = pool.map(partial_f, image_paths_index)
        data = np.stack(data_list, axis=0)
        images_2.append(data)
    print(np.shape(images_2))

    # Segmentations of the same cases.
    partial_f = partial(load_images,
                        seq='seg',
                        target_shape=IMAGE_SHAPE,
                        target_space=TARGET_SPACING)
    data_list = pool.map(partial_f, seg_paths_index)
segmentations = np.stack(data_list, axis=0)

# images_2 is indexed as [series, case, ...]; move the series axis to the end
# so that it becomes the last (channel) axis of the model input.
# (This replaces an older sequential loading loop built on preprocess().)
images_list = np.transpose(images_2, (1, 2, 3, 4, 0))

print("images size ", np.shape(images_list))
print("size segmentation", np.shape(segmentations))
import os
# Restrict the process to GPU "2".
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

########### load module ##################
print(' >>>>>>> LOAD MODEL <<<<<<<<<')

# dice_coef is a project function, so it has to be handed to load_model for
# the saved model to be restored.
dependencies = dict(dice_coef=dice_coef)
reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies)

# Predict every selected case, one case per batch.
print(' >>>>>>> START prediction <<<<<<<<<')
predictions_blur = reconstructed_model.predict(images_list, batch_size=1)

# Next step: turn the raw predictions into individual blobs.
print('>>>>>>>> START preprocess')
def move_dims(arr):
    """Reorder a batch of volumes from (batch, width, height, depth) to
    (batch, depth, height, width).

    UMCG numpy dimensions convention: dims = (batch, width, heigth, depth)
    Joeran numpy dimensions convention: dims = (batch, depth, heigth, width)
    """
    # Exchanging axis 1 and axis 3 is exactly the effect of the two successive
    # axis moves (3 -> 1, then 3 -> 2); axis 0 and axis 2 stay where they are.
    return np.swapaxes(arr, 1, 3)
def _drop_channel_axis(arr):
    """Remove a trailing single-channel axis from a 5-D array, if present."""
    arr = np.asarray(arr)
    if arr.ndim == 5 and arr.shape[-1] == 1:
        return arr[..., 0]
    return arr


# Joeran has his numpy arrays ordered differently.
# Only the trailing channel axis is removed. A bare np.squeeze() would also
# drop the batch axis when exactly one case is evaluated, after which
# move_dims() fails.
predictions_blur = move_dims(_drop_channel_axis(predictions_blur))
segmentations = move_dims(_drop_channel_axis(segmentations))

# Remove the blur and make individual blobs; element [0] of the result of
# preprocess_softmax is kept for every case.
predictions = np.stack(
    [preprocess_softmax(pred, threshold="dynamic")[0] for pred in predictions_blur])

# Remove outer edges: zero a border of EDGE_WIDTH voxels along the last two
# axes. Negative slice ends keep this correct for any in-plane size (for 192
# voxels this is the former hard-coded 2:190).
EDGE_WIDTH = 2
inner = (Ellipsis, slice(EDGE_WIDTH, -EDGE_WIDTH), slice(EDGE_WIDTH, -EDGE_WIDTH))
zeros = np.zeros(predictions.shape)
zeros[inner] = predictions[inner]
predictions = zeros

# perform Froc
metrics = evaluate(y_true=segmentations, y_pred=predictions)
dump_dict_to_yaml(metrics, YAML_DIR, "froc_metrics", verbose=True)
# Save one example case (raw prediction, blob prediction, segmentation) as
# NIfTI files. The files go into the folder of the evaluated experiment
# (YAML_DIR); the previous hard-coded 'train_10h_' folder belonged to a
# different experiment than the model loaded from MODEL_PATH.
IMAGE_DIR = YAML_DIR
# Case 3 is exported when available, otherwise the last case.
EXAMPLE_IDX = min(3, len(predictions) - 1)
for volume_name, volumes in (("predictions_blur", predictions_blur),
                             ("predictions", predictions),
                             ("segmentations", segmentations)):
    img_s = sitk.GetImageFromArray(volumes[EXAMPLE_IDX].squeeze())
    sitk.WriteImage(img_s, f"{IMAGE_DIR}/{volume_name}_001.nii.gz")
# create plot
# json_path = './../scripts/metrics.json'
# f = open(json_path)
# data = json.load(f)
# x = data['fpr']
# y = data['tpr']
# auroc = data['auroc']
# plt.plot(x,y)

68
scripts/5.Visualize_frocs.py Executable file
View File

@ -0,0 +1,68 @@
import sys
sys.path.append('./../code')
from utils_quintin import *
import matplotlib.pyplot as plt
import argparse
parser = argparse.ArgumentParser(
    description='Visualise froc results')
parser.add_argument('-saveas',
                    help='Suffix of the output files: fROC_<saveas>.png and ROC_<saveas>.png')
parser.add_argument('-comparison',
                    help='If given (any non-empty value), experiments are styled in pairs: ' +
                         'same colour, solid line followed by dashed line')
parser.add_argument('--experiment', '-s',
                    metavar='[experiment_name]', required=True, nargs='+',
                    help='List of experiment folders in ./../train_output/ whose ' +
                         'froc_metrics.yml should be plotted')
args = parser.parse_args()

# Line styles per experiment. In comparison mode two consecutive experiments
# share a colour and differ in line style.
if args.comparison:
    colors = ['r', 'r', 'b', 'b', 'g', 'g']
    plot_type = ['-', '--', '-', '--', '-', '--']
else:
    colors = ['r', 'b', 'g', 'k']
    plot_type = ['-', '-', '-', '-']

experiments = args.experiment
print(experiments)

auroc = []
for idx, experiment in enumerate(experiments):
    experiment_path = f'./../train_output/{experiment}/froc_metrics.yml'
    experiment_metrics = read_yaml_to_dict(experiment_path)
    auroc.append(round(experiment_metrics['auroc'], 3))

    # Wrap around instead of raising IndexError when more experiments are
    # given than there are style entries.
    color = colors[idx % len(colors)]
    linestyle = plot_type[idx % len(plot_type)]

    # Figure 1: fROC curve, figure 2: ROC curve.
    plt.figure(1)
    plt.plot(experiment_metrics["FP_per_case"], experiment_metrics["sensitivity"],
             color=color, linestyle=linestyle)
    plt.figure(2)
    plt.plot(experiment_metrics["fpr"], experiment_metrics["tpr"],
             color=color, linestyle=linestyle)

print(auroc)
# Legend labels: strip the known experiment prefixes and show the remaining
# underscores as spaces.
experiments = [
    name.replace('train_10h_', '').replace('train_n0.001_', '').replace('_', ' ')
    for name in experiments
]

# Figure 1: fROC curve.
plt.figure(1)
plt.title('fROC curve')
plt.xlabel('False positive per case')
plt.ylabel('Sensitivity')
plt.legend(experiments, loc='lower right')
plt.xlim([0, 3])
plt.ylim([0, 1])
plt.yticks([0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1])
plt.grid()
plt.savefig(f"./../train_output/fROC_{args.saveas}.png", dpi=300)

# Figure 2: ROC curve; every legend entry carries its AUROC in parentheses.
experiments_auroc = [
    label + " (" + str(score) + ")" for label, score in zip(experiments, auroc)
]
plt.figure(2)
plt.title('ROC curve')
plt.legend(experiments_auroc, loc='lower right')
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.grid()
plt.savefig(f"./../train_output/ROC_{args.saveas}.png", dpi=300)

108
scripts/6.saliency_map.py Executable file
View File

@ -0,0 +1,108 @@
import sys
from os import path
import SimpleITK as sitk
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import load_model
from focal_loss import BinaryFocalLoss
import json
import matplotlib.pyplot as plt
import numpy as np
from sfransen.Saliency.base import *
from sfransen.Saliency.integrated_gradients import *
# from tensorflow.keras.vis.visualization import visualize_saliency
sys.path.append('./../code')
from utils_quintin import *
sys.path.append('./../code/DWI_exp')
# from preprocessing_function import preprocess
from sfransen.DWI_exp import preprocess
print("done step 1")
from sfransen.DWI_exp.helpers import *
# from helpers import *
from callbacks import dice_coef
sys.path.append('./../code/FROC')
from blob_preprocess import *
from cal_froc_from_np import *
# NOTE: a stray quit() stood here (presumably left over from checking the
# imports above); it ended the script before any saliency map was computed
# and has been removed.

# train_10h_t2_b50_b400_b800_b1400_adc
SERIES = ['t2', 'b50', 'b400', 'b800', 'b1400', 'adc']
MODEL_PATH = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/models/train_10h_t2_b50_b400_b800_b1400_adc.h5'
YAML_DIR = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc'

################ constants ############
DATA_DIR = "./../data/Nijmegen paths/"
TARGET_SPACING = (0.5, 0.5, 3)
INPUT_SHAPE = (192, 192, 24, len(SERIES))
IMAGE_SHAPE = INPUT_SHAPE[:3]

# Metrics of the evaluated experiment and the validation indexes of split 0.
experiment_path = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/froc_metrics.yml'
experiment_metrics = read_yaml_to_dict(experiment_path)
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = DATA_SPLIT_INDEX['val_set0']

# Keep the ten cases with the highest 'roc_pred' value, highest first.
# argsort() sorts ascending, so the tail is reversed; otherwise taking the
# first five entries further down would select ranks 10 to 6 instead of the
# five highest predictions.
# NOTE(review): assumes 'roc_pred' holds one score per validation case, in
# the order of 'val_set0' - confirm.
top_10_idx = np.argsort(experiment_metrics['roc_pred'])[-10:][::-1]
TEST_INDEX = [TEST_INDEX[i] for i in top_10_idx]
########## load images ##############
# One list of preprocessed volumes per series, plus the path lists.
images = {}
for series_name in SERIES:
    images[series_name] = []
image_paths = {}
segmentations = []
print_(f"> Loading images into RAM...")

# "[series_name].txt" lists one image path per line; "seg.txt" lists the
# segmentation paths.
for series_name in SERIES:
    with open(path.join(DATA_DIR, f"{series_name}.txt"), 'r') as path_file:
        image_paths[series_name] = [line.strip() for line in path_file]
with open(path.join(DATA_DIR, f"seg.txt"), 'r') as path_file:
    seg_paths = [line.strip() for line in path_file]
num_images = len(seg_paths)

# Read and preprocess the first five selected cases only.
for img_idx in TEST_INDEX[:5]:
    img_s = {}
    for series_name in SERIES:
        img_s[series_name] = sitk.ReadImage(
            image_paths[series_name][img_idx], sitk.sitkFloat32)
    seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
    img_n, seg_n = preprocess(
        img_s, seg_s, shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
    for seq in img_n:
        images[seq].append(img_n[seq])
    segmentations.append(seg_n)

# From dict to array: the per-series lists are stacked with the series axis
# first, then that axis is moved to the end.
images_list = list(images.values())
images_list = np.transpose(images_list, (1, 2, 3, 4, 0))
########### load module ##################
# dice_coef is a project function, so it has to be handed to load_model.
dependencies = {
    'dice_coef': dice_coef
}
reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies)

print('START prediction')

ig = IntegratedGradients(reconstructed_model)
saliency_map = []
for image in images_list:
    # Add the leading batch axis. np.expand_dims never changes the data,
    # whereas np.resize silently repeats or truncates values when the
    # requested shape does not match the image.
    # (Christian Roest, 11-3-2022: for a 48x48x8 input with 8 series,
    # input_img has dimensions (1, 48, 48, 8, 8).)
    input_img = np.expand_dims(image, axis=0)
    saliency_map.append(ig.get_mask(input_img).numpy())

print("size saliency map is:", np.shape(saliency_map))
np.save('saliency', saliency_map)

print('START saliency')

90
scripts/7.Visualize_saliency.py Executable file
View File

@ -0,0 +1,90 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
# Load the saliency maps written by the saliency script and drop singleton axes.
heatmap = np.load('saliency.npy')
print(np.shape(heatmap))
heatmap = np.squeeze(heatmap)
print(np.shape(heatmap))

### take average over 5 #########
# Mean of the absolute saliency over the first axis (the cases).
heatmap = np.mean(abs(heatmap), axis=0)
print(np.shape(heatmap))

SERIES = ['t2', 'b50', 'b400', 'b800', 'b1400', 'adc']
SLICE_IDX = 12  # slice shown for every series

fig, axes = plt.subplots(1, len(SERIES))

# vmin/vmax of the whole heatmap are used to scale every panel in imshow, so
# the panels are comparable and the shared colorbar is valid for all of them.
# (A stray token `pri` stood here and raised a NameError; the first panel
# also lacked vmin/vmax.)
# TODO: cmap to grey
max_value = np.amax(heatmap)
min_value = np.amin(heatmap)

for channel, series_name in enumerate(SERIES):
    im = axes[channel].imshow(np.squeeze(heatmap[:, :, SLICE_IDX, channel]),
                              vmin=min_value, vmax=max_value)
    axes[channel].set_title(series_name)

cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.5, orientation='horizontal')
# NOTE(review): the heatmap is a mean of absolute values and therefore not
# negative, so the -0.1 tick lies outside the data range - confirm the ticks.
cbar.set_ticks([-0.1, 0, 0.1])
cbar.set_ticklabels(['less importance', '0', 'important'])
fig.suptitle('Average saliency maps over the 5 highest predictions', fontsize=16)
plt.show()

quit()
# NOTE: the statements below follow a quit() call and are therefore not
# executed when the script runs from the top.
#take one image out
heatmap = np.squeeze(heatmap[0])
import numpy as np
import matplotlib.pyplot as plt
# Fixing random state for reproducibility
np.random.seed(19680801)
class IndexTracker:
    """Show one slice of a 3-D array on an axis and step through the slices.

    ``on_scroll`` is meant to be connected to a figure's 'scroll_event'.
    """

    def __init__(self, ax, X):
        """Draw the middle slice of ``X`` (indexed along its last axis) on ``ax``."""
        self.ax = ax
        ax.set_title('use scroll wheel to navigate images')

        self.X = X
        # X must have exactly three axes; the last one is the slice axis.
        _, _, self.slices = X.shape
        self.ind = self.slices // 2

        self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet')
        self.update()

    def on_scroll(self, event):
        """Move one slice forward on 'up', one slice back otherwise (wrapping)."""
        print("%s %s" % (event.button, event.step))
        step = 1 if event.button == 'up' else -1
        self.ind = (self.ind + step) % self.slices
        self.update()

    def update(self):
        """Show the current slice, label the axis with its index and redraw."""
        current_slice = self.X[:, :, self.ind]
        self.im.set_data(current_slice)
        self.ax.set_ylabel('slice %s' % self.ind)
        self.im.axes.figure.canvas.draw()
# Open two scrollable viewers one after the other: figure number 0 shows
# channel 5 of the heatmap, figure number 1 shows channel 3. Each plt.show()
# is issued before the next viewer is created.
for figure_number, channel in ((0, 5), (1, 3)):
    plt.figure(figure_number)
    fig, ax = plt.subplots(1, 1)
    tracker = IndexTracker(ax, heatmap[:, :, :, channel])
    fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
    plt.show()

59
scripts/8.Visualize_training.py Executable file
View File

@ -0,0 +1,59 @@
import matplotlib.pyplot as plt
import pandas as pd
import glob
import argparse
import os
#create parser
def parse_input_args():
    """Parse the command line of the training-plot script.

    Returns:
        argparse.Namespace with the single positional value ``train_out_dir``.
    """
    cli = argparse.ArgumentParser(description='Parse arguments for training a Reconstruction model')
    cli.add_argument(
        'train_out_dir',
        type=str,
        help='Directory name in train_output dir of the desired experiment folder. There should be a .csv file in this directory with train statistics.',
    )
    return cli.parse_args()
args = parse_input_args()
print(f"Plotting {args}")

# NOTE: despite its name, train_out_dir is used as the path of the CSV file
# with the training statistics itself.
folder_input = args.train_out_dir

# load csv file
df = pd.read_csv(f'{folder_input}')

# Plot the training and validation loss against the epoch.
for metric in df:
    if metric == 'loss' or metric == 'val_loss':
        plt.plot(df['epoch'], df[metric], label=metric)
plt.ylim(ymin=0, ymax=0.01)

# The figure is stored next to the CSV file and named after the folder two
# levels above it. When that name is empty (CSV given without enough parent
# folders), the CSV file name is used instead.
folder, csvfile = os.path.split(args.train_out_dir)
root, experiment = os.path.split(os.path.split(folder)[0])
if not experiment:
    experiment = os.path.splitext(csvfile)[0]
# os.path.join keeps the path relative when `folder` is empty; plain string
# formatting produced a path in the filesystem root ("/<name>.png").
figure_path = os.path.join(folder, f"{experiment}.png")

plt.title(experiment)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid()
plt.legend()
plt.savefig(figure_path)
plt.clf()
plt.close()

# Report the file that was actually written (this used to print
# folder + ".png", which is not where the figure is saved).
print(figure_path)
print(f"\nsaved figure to {folder}")
# TODO: load 'loss' from yaml
# TODO: from utils > dict to yaml
# csv viewer extension
# source .venv/bin/activate
# TODO: loop over everything and if metric then 1-metric
# TODO: colour coding color=[]