add scripts
This commit is contained in:
parent
8cd3d865da
commit
be5b392456
|
@ -143,7 +143,6 @@ cython_debug/
|
|||
/old_code/
|
||||
/data/
|
||||
/job_scripts/
|
||||
/scripts/
|
||||
/temp/
|
||||
/slurms/
|
||||
*.out
|
||||
|
|
|
@ -0,0 +1,203 @@
|
|||
import multiprocessing
|
||||
from os import path
|
||||
import argparse
|
||||
import time
|
||||
from datetime import datetime
|
||||
import sys
|
||||
# sys.path.append('./../code')
|
||||
# from utils_quintin import *
|
||||
# from sfransen.utils_quintin import *
|
||||
# sys.path.append('./../code/DWI_exp')
|
||||
# from callbacks import IntermediateImages, dice_coef
|
||||
# from callbacks import RocCallback
|
||||
from sfransen.DWI_exp import IntermediateImages, dice_coef
|
||||
from sfransen.DWI_exp.preprocessing_function import preprocess
|
||||
import yaml
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import tensorflow as tf
|
||||
tf.compat.v1.disable_eager_execution()
|
||||
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger
|
||||
from tensorflow.keras.optimizers import Adam
|
||||
from sklearn.model_selection import KFold
|
||||
|
||||
from sfransen.DWI_exp.helpers import *
|
||||
from sfransen.DWI_exp.batchgenerator import BatchGenerator
|
||||
|
||||
from sfransen.DWI_exp.unet import build_dual_attention_unet
|
||||
from focal_loss import BinaryFocalLoss
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Train a U-Net model for segmentation/detection tasks.' +
|
||||
'using cross-validation.')
|
||||
parser.add_argument('--series', '-s',
|
||||
metavar='[series_name]', required=True, nargs='+',
|
||||
help='List of series to include, must correspond with' +
|
||||
"path files in ./data/")
|
||||
parser.add_argument('-experiment',
|
||||
help='add experiment title to store the files correctly: test_b50_b400_b800'
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine the number of input series
|
||||
num_series = len(args.series)
|
||||
|
||||
# Identify this job by the series included in the training
|
||||
# Output folder will have this name, e.g.: b0_b50_b100
|
||||
# JOB_NAME = '_'.join(args.series)
|
||||
JOB_NAME = args.experiment
|
||||
DATA_DIR = "./../data/Nijmegen paths/"
|
||||
# DATA_DIR = "./../data/new/"
|
||||
# PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{args.experiment}"
|
||||
PROJECT_DIR = f"/data/pca-rad/sfransen/train_output/{JOB_NAME}"
|
||||
# 2 x 2mm2 in-plane resolution, 3.6mm slice thickness
|
||||
TARGET_SPACING = (0.5, 0.5, 3)
|
||||
INPUT_SHAPE = (192, 192, 24, num_series) #(64, 64, 20, num_series)
|
||||
IMAGE_SHAPE = INPUT_SHAPE[:3]
|
||||
OUTPUT_SHAPE = (192, 192, 24, 1) # One output channel (segmentation)
|
||||
|
||||
# Hyperparameters
|
||||
FOCAL_LOSS_GAMMA = 2
|
||||
INITIAL_LEARNING_RATE = 1e-4
|
||||
MAX_EPOCHS = 600
|
||||
EARLY_STOPPING = 50
|
||||
# increase batch size
|
||||
BATCH_SIZE = 12
|
||||
MODEL_SELECTION_METRIC = 'val_loss'
|
||||
MODEL_SELECTION_DIRECTION = "min" # Change to 'max' if higher value is better
|
||||
EARLY_STOPPING_METRIC = 'val_loss'
|
||||
EARLY_STOPPING_DIRECTION = "min" # Change to 'max' if higher value is better
|
||||
|
||||
# Training configuration
|
||||
# add metric ROC_AUC
|
||||
TRAINING_METRICS = ["binary_crossentropy", "binary_accuracy", dice_coef]
|
||||
loss = BinaryFocalLoss(gamma=FOCAL_LOSS_GAMMA)
|
||||
optimizer = Adam(learning_rate=INITIAL_LEARNING_RATE)
|
||||
|
||||
# Create folder structure in the output directory
|
||||
if path.exists(PROJECT_DIR):
|
||||
prepare_project_dir(PROJECT_DIR+'_(2)')
|
||||
else:
|
||||
prepare_project_dir(PROJECT_DIR)
|
||||
|
||||
#save params to yaml
|
||||
params = {
|
||||
"focal_loss_gamma": FOCAL_LOSS_GAMMA,
|
||||
"initial_learning_rate": INITIAL_LEARNING_RATE,
|
||||
"max_epochs": MAX_EPOCHS,
|
||||
"MODEL_SELECTION_METRIC": MODEL_SELECTION_METRIC,
|
||||
"EARLY_STOPPING_METRIC": EARLY_STOPPING_METRIC,
|
||||
"train_output_dir": PROJECT_DIR,
|
||||
"batch_size": BATCH_SIZE,
|
||||
"optimizer": optimizer,
|
||||
"loss": loss,
|
||||
"datetime": print(datetime.now().strftime("%Y-%m-%d"))}
|
||||
dump_dict_to_yaml(params, f"{PROJECT_DIR}", filename=f"params")
|
||||
|
||||
# Build the U-Net model
|
||||
detection_model = build_dual_attention_unet(INPUT_SHAPE)
|
||||
detection_model.summary(line_length=120)
|
||||
|
||||
# Load all numpy images into RAM
|
||||
images, image_paths = {s: [] for s in args.series}, {}
|
||||
segmentations = []
|
||||
print_(f"> Loading images into RAM...")
|
||||
|
||||
# Read the image paths from the data directory.
|
||||
# Texts files are expected to have the name "[series_name].txt"
|
||||
for s in args.series:
|
||||
with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f:
|
||||
image_paths[s] = [l.strip() for l in f.readlines()]
|
||||
with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f:
|
||||
seg_paths = [l.strip() for l in f.readlines()]
|
||||
num_images = len(seg_paths)
|
||||
|
||||
# Read and preprocess each of the paths for each series, and the segmentations.
|
||||
for img_idx in tqdm(range(num_images)): # [:40]):for less images
|
||||
img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)
|
||||
for s in args.series}
|
||||
seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
|
||||
img_n, seg_n = preprocess(img_s, seg_s,
|
||||
shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
|
||||
for seq in img_n:
|
||||
images[seq].append(img_n[seq])
|
||||
segmentations.append(seg_n)
|
||||
|
||||
# Split train and validation
|
||||
# We use KFold to split the data, but we don't actually do cross validation, we
|
||||
# just use it to split the data 1:9.
|
||||
# kfold = KFold(10, shuffle=True, random_state=123)
|
||||
# train_idxs, valid_idxs = list(kfold.split(segmentations))[0]
|
||||
# train_idxs = list(train_idxs)
|
||||
# valid_idxs = list(valid_idxs)
|
||||
|
||||
yml_paths = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
|
||||
train_idxs = yml_paths['train_set0']
|
||||
valid_idxs = yml_paths['val_set0']
|
||||
|
||||
|
||||
detection_model.compile(
|
||||
optimizer=optimizer,
|
||||
loss=loss,
|
||||
metrics=TRAINING_METRICS
|
||||
)
|
||||
|
||||
train_generator = BatchGenerator(images, segmentations,
|
||||
sequences=args.series,
|
||||
shape=IMAGE_SHAPE,
|
||||
indexes=train_idxs,
|
||||
batch_size=BATCH_SIZE,
|
||||
shuffle=True,
|
||||
augmentation_function=augment
|
||||
)
|
||||
|
||||
valid_generator = get_generator(images, segmentations,
|
||||
sequences=args.series,
|
||||
shape=IMAGE_SHAPE,
|
||||
indexes=valid_idxs,
|
||||
batch_size=None,
|
||||
shuffle=False,
|
||||
augmentation=None
|
||||
)
|
||||
valid_data = next(valid_generator)
|
||||
print_(f"The shape of valid_data input = {np.shape(valid_data[0])}")
|
||||
print_(f"The shape of valid_data label = {np.shape(valid_data[1])}")
|
||||
|
||||
callbacks = [
|
||||
EarlyStopping(
|
||||
monitor=EARLY_STOPPING_METRIC,
|
||||
mode=EARLY_STOPPING_DIRECTION,
|
||||
patience=EARLY_STOPPING,
|
||||
verbose=1),
|
||||
ModelCheckpoint(
|
||||
filepath=path.join(PROJECT_DIR, "models", JOB_NAME + ".h5"),
|
||||
monitor=MODEL_SELECTION_METRIC,
|
||||
mode=MODEL_SELECTION_DIRECTION,
|
||||
verbose=1,
|
||||
save_best_only=True),
|
||||
# ModelCheckpoint(
|
||||
# filepath=path.join(PROJECT_DIR, "models_dice", JOB_NAME + ".h5"),
|
||||
# monitor='val_dice_coef',
|
||||
# mode='max',
|
||||
# verbose=0,
|
||||
# save_best_only=True),
|
||||
CSVLogger(
|
||||
filename=path.join(PROJECT_DIR, "logs", f"{JOB_NAME}.csv")),
|
||||
IntermediateImages(
|
||||
validation_set=valid_data,
|
||||
sequences=args.series,
|
||||
prefix=path.join(PROJECT_DIR, "output", JOB_NAME),
|
||||
num_images=25)
|
||||
# RocCallback(
|
||||
# validation_set=valid_data,
|
||||
# num_images=25)
|
||||
]
|
||||
|
||||
detection_model.fit(train_generator,
|
||||
validation_data = valid_data,
|
||||
steps_per_epoch = len(train_idxs) // BATCH_SIZE,
|
||||
epochs = MAX_EPOCHS,
|
||||
callbacks = callbacks,
|
||||
verbose = 1,
|
||||
)
|
|
@ -0,0 +1,107 @@
|
|||
import argparse
|
||||
import random
|
||||
import sys
|
||||
sys.path.append('./../code')
|
||||
from utils_quintin import list_from_file, dump_dict_to_yaml, print_p
|
||||
|
||||
################################ README ######################################
|
||||
# NEW - This script will create indexes for the training, validation and test
|
||||
# sets. Create a yaml file with 10 sets of indexes for train, val and test. So
|
||||
# that they can be used for 10 fold cross validation when needed. The filename
|
||||
# should contain an integer, which is the seed used to generate the indexes. The
|
||||
# filename also indicates the type of split used.
|
||||
# Output structure:
|
||||
# trainSet0: [indexes]
|
||||
# valSet0: [indexes]
|
||||
# testSet0: [indexes]
|
||||
# trainSet1: [indexes]
|
||||
# valSet1: [indexes]
|
||||
# testSet1: [indexes]
|
||||
# etc...
|
||||
|
||||
|
||||
################################ PARSER ########################################
|
||||
|
||||
|
||||
def parse_input_args():
|
||||
parser = argparse.ArgumentParser(description='Parse arguments for splitting training, validation and the test set.')
|
||||
|
||||
parser.add_argument('-n',
|
||||
'--num_folds',
|
||||
type=int,
|
||||
default=1,
|
||||
help='The number of folds in total. This amount of index sets will be created.')
|
||||
|
||||
parser.add_argument('-p',
|
||||
'--path_to_paths_file',
|
||||
type=str,
|
||||
default="./../data/Nijmegen paths/seg.txt",
|
||||
help='Path to the .txt file containting paths to the nifti files.')
|
||||
|
||||
parser.add_argument('-s',
|
||||
'--split',
|
||||
type=str,
|
||||
default="80/10/10",
|
||||
help='Train/validation/test split in percentages.')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# split the given split string.
|
||||
args.p_train = int(args.split.split('/')[0])
|
||||
args.p_val = int(args.split.split('/')[1])
|
||||
args.p_test = int(args.split.split('/')[2])
|
||||
|
||||
assert args.p_train + args.p_val + args.p_test == 100, "The train, val, test split to sum to 100%."
|
||||
|
||||
return args
|
||||
|
||||
|
||||
################################################################################
|
||||
SEED = 3478
|
||||
|
||||
if __name__ == '__main__':
|
||||
print_p('\n\nMaking Train - Validation - Test indexes based.')
|
||||
|
||||
# Parse some arguments
|
||||
args = parse_input_args()
|
||||
print_p(args)
|
||||
|
||||
# Read the amount of observations/subjects in the data.
|
||||
t2_paths = list_from_file(args.path_to_paths_file)
|
||||
num_obs = len(t2_paths)
|
||||
print_p(f"Number of observations in {args.path_to_paths_file}: {len(t2_paths)}")
|
||||
|
||||
# Create cutoff points for training, validation and test set.
|
||||
train_cutoff = int(args.p_train/100 * num_obs)
|
||||
val_cutoff = int(args.p_val/100 * num_obs) + train_cutoff
|
||||
test_cutoff = int(args.p_test/100 * num_obs) + val_cutoff
|
||||
print(f"\ncutoffs: {train_cutoff}, {val_cutoff}, {test_cutoff}")
|
||||
|
||||
# Create dict that will hold all the data
|
||||
data_dict = {}
|
||||
data_dict["init_seed"] = SEED
|
||||
data_dict["split"] = args.split
|
||||
|
||||
# loop over the amount of folds, that many sets will be created in a yaml file.
|
||||
for set_idx in range(args.num_folds):
|
||||
|
||||
# Set new seed first
|
||||
random.seed(SEED + set_idx)
|
||||
|
||||
# shuffle the indexes
|
||||
indexes = list(range(num_obs))
|
||||
random.shuffle(indexes)
|
||||
|
||||
train_idxs = indexes[:train_cutoff]
|
||||
val_idxs = indexes[train_cutoff:val_cutoff]
|
||||
test_idxs = indexes[val_cutoff:test_cutoff]
|
||||
|
||||
data_dict[f"train_set{set_idx}"] = train_idxs
|
||||
data_dict[f"val_set{set_idx}"] = val_idxs
|
||||
data_dict[f"test_set{set_idx}"] = test_idxs
|
||||
|
||||
for key in data_dict:
|
||||
if type(data_dict[key]) == list:
|
||||
print(f"{key}: {len(data_dict[key])}")
|
||||
|
||||
dump_dict_to_yaml(data_dict, "./../data", filename=f"train_val_test_idxs", verbose=False)
|
|
@ -0,0 +1,217 @@
|
|||
import sys
|
||||
from os import path
|
||||
import SimpleITK as sitk
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras.models import load_model
|
||||
from focal_loss import BinaryFocalLoss
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import multiprocessing
|
||||
from functools import partial
|
||||
|
||||
sys.path.append('./../code')
|
||||
from utils_quintin import *
|
||||
|
||||
sys.path.append('./../code/DWI_exp')
|
||||
from helpers import *
|
||||
from preprocessing_function import preprocess
|
||||
from callbacks import dice_coef
|
||||
|
||||
sys.path.append('./../code/FROC')
|
||||
from blob_preprocess import *
|
||||
from cal_froc_from_np import *
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Train a U-Net model for segmentation/detection tasks.' +
|
||||
'using cross-validation.')
|
||||
parser.add_argument('--series', '-s',
|
||||
metavar='[series_name]', required=True, nargs='+',
|
||||
help='List of series to include, must correspond with' +
|
||||
"path files in ./data/")
|
||||
args = parser.parse_args()
|
||||
|
||||
######## parsed inputs #############
|
||||
# SERIES = ['b50', 'b400', 'b800'] #can be parsed
|
||||
SERIES = args.series
|
||||
series_ = '_'.join(args.series)
|
||||
# Import model
|
||||
# MODEL_PATH = f'./../train_output/train_10h_{series_}/models/train_10h_{series_}.h5'
|
||||
# YAML_DIR = f'./../train_output/train_10h_{series_}'
|
||||
MODEL_PATH = f'./../train_output/train_n0.001_{series_}/models/train_n0.001_{series_}.h5'
|
||||
print(MODEL_PATH)
|
||||
YAML_DIR = f'./../train_output/train_n0.001_{series_}'
|
||||
################ constants ############
|
||||
DATA_DIR = "./../data/Nijmegen paths/"
|
||||
TARGET_SPACING = (0.5, 0.5, 3)
|
||||
INPUT_SHAPE = (192, 192, 24, len(SERIES))
|
||||
IMAGE_SHAPE = INPUT_SHAPE[:3]
|
||||
|
||||
# import val_indx
|
||||
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
|
||||
TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
|
||||
|
||||
|
||||
########## load images ##############
|
||||
images, image_paths = {s: [] for s in SERIES}, {}
|
||||
segmentations = []
|
||||
print_(f"> Loading images into RAM...")
|
||||
|
||||
for s in SERIES:
|
||||
with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f:
|
||||
image_paths[s] = [l.strip() for l in f.readlines()]
|
||||
with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f:
|
||||
seg_paths = [l.strip() for l in f.readlines()]
|
||||
num_images = len(seg_paths)
|
||||
|
||||
# Read and preprocess each of the paths for each SERIES, and the segmentations.
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def load_images(
|
||||
image_paths: str,
|
||||
seq: str,
|
||||
target_shape: List[int],
|
||||
target_space = List[float]):
|
||||
|
||||
img_s = sitk.ReadImage(image_paths, sitk.sitkFloat32)
|
||||
|
||||
#resample
|
||||
mri_tra_s = resample(img_s,
|
||||
min_shape=target_shape,
|
||||
method=sitk.sitkNearestNeighbor,
|
||||
new_spacing=target_space)
|
||||
|
||||
#center crop
|
||||
mri_tra_s = center_crop(mri_tra_s, shape=target_shape)
|
||||
#normalize
|
||||
if seq != 'seg':
|
||||
filter = sitk.NormalizeImageFilter()
|
||||
mri_tra_s = filter.Execute(mri_tra_s)
|
||||
else:
|
||||
filter = sitk.BinaryThresholdImageFilter()
|
||||
filter.SetLowerThreshold(1.0)
|
||||
mri_tra_s = filter.Execute(mri_tra_s)
|
||||
|
||||
return sitk.GetArrayFromImage(mri_tra_s).T
|
||||
|
||||
N_CPUS = 12
|
||||
pool = multiprocessing.Pool(processes=N_CPUS)
|
||||
partial_f = partial(load_images,
|
||||
seq = 'images',
|
||||
target_shape=IMAGE_SHAPE,
|
||||
target_space = TARGET_SPACING)
|
||||
|
||||
images_2 = []
|
||||
for s in SERIES:
|
||||
image_paths_seq = image_paths[s]
|
||||
image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX]
|
||||
data_list = pool.map(partial_f,image_paths_index)
|
||||
data = np.stack(data_list, axis=0)
|
||||
images_2.append(data)
|
||||
# print(s)
|
||||
# print(np.shape(data))
|
||||
print(np.shape(images_2))
|
||||
|
||||
partial_f = partial(load_images,
|
||||
seq = 'seg',
|
||||
target_shape=IMAGE_SHAPE,
|
||||
target_space = TARGET_SPACING)
|
||||
seg_paths_index = np.asarray(seg_paths)[TEST_INDEX]
|
||||
data_list = pool.map(partial_f,seg_paths_index)
|
||||
segmentations = np.stack(data_list, axis=0)
|
||||
|
||||
# print("segmentations pool",np.shape(segmentations_2))
|
||||
|
||||
# for img_idx in TEST_INDEX: #for less images
|
||||
# img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)
|
||||
# for s in SERIES}
|
||||
# seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
|
||||
# img_n, seg_n = preprocess(img_s, seg_s,
|
||||
# shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
|
||||
# for seq in img_n:
|
||||
# images[seq].append(img_n[seq])
|
||||
# segmentations.append(seg_n)
|
||||
|
||||
# print("segmentations old",np.shape(segmentations))
|
||||
|
||||
# # from dict to list
|
||||
# # images_list = [img nmbr, [INPUT_SHAPE]]
|
||||
# images_list = [images[s] for s in images.keys()]
|
||||
# images_list = np.transpose(images_list, (1, 2, 3, 4, 0))
|
||||
images_list = np.transpose(images_2, (1, 2, 3, 4, 0))
|
||||
|
||||
print("images size ",np.shape(images_list))
|
||||
print("size segmentation",np.shape(segmentations))
|
||||
# print("images size pool",np.shape(images_list_2))
|
||||
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
|
||||
########### load module ##################
|
||||
print(' >>>>>>> LOAD MODEL <<<<<<<<<')
|
||||
|
||||
dependencies = {
|
||||
'dice_coef': dice_coef
|
||||
}
|
||||
reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies)
|
||||
# reconstructed_model.summary(line_length=120)
|
||||
|
||||
# make predictions on all val_indx
|
||||
print(' >>>>>>> START prediction <<<<<<<<<')
|
||||
predictions_blur = reconstructed_model.predict(images_list, batch_size=1)
|
||||
|
||||
# print("The shape of the predictions list is: ",np.shape(predictions_blur))
|
||||
# print(type(predictions))
|
||||
# np.save('predictions',predictions)
|
||||
|
||||
# preprocess predictions by removing the blur and making individual blobs
|
||||
print('>>>>>>>> START preprocess')
|
||||
|
||||
def move_dims(arr):
|
||||
# UMCG numpy dimensions convention: dims = (batch, width, heigth, depth)
|
||||
# Joeran numpy dimensions convention: dims = (batch, depth, heigth, width)
|
||||
arr = np.moveaxis(arr, 3, 1) # Joeran has his numpy arrays ordered differently.
|
||||
arr = np.moveaxis(arr, 3, 2)
|
||||
return arr
|
||||
|
||||
# Joeran has his numpy arrays ordered differently.
|
||||
predictions_blur = move_dims(np.squeeze(predictions_blur))
|
||||
segmentations = move_dims(np.squeeze(segmentations))
|
||||
predictions = [preprocess_softmax(pred, threshold="dynamic")[0] for pred in predictions_blur]
|
||||
|
||||
# Remove outer edges
|
||||
zeros = np.zeros(np.shape(predictions))
|
||||
test = np.squeeze(predictions)[:,:,2:190,2:190]
|
||||
zeros[:,:,2:190,2:190] = test
|
||||
predictions = zeros
|
||||
|
||||
# perform Froc
|
||||
metrics = evaluate(y_true=segmentations, y_pred=predictions)
|
||||
dump_dict_to_yaml(metrics, YAML_DIR, "froc_metrics", verbose=True)
|
||||
|
||||
|
||||
|
||||
# save one image
|
||||
IMAGE_DIR = f'./../train_output/train_10h_{series_}'
|
||||
img_s = sitk.GetImageFromArray(predictions_blur[3].squeeze())
|
||||
sitk.WriteImage(img_s, f"{IMAGE_DIR}/predictions_blur_001.nii.gz")
|
||||
|
||||
img_s = sitk.GetImageFromArray(predictions[3].squeeze())
|
||||
sitk.WriteImage(img_s, f"{IMAGE_DIR}/predictions_001.nii.gz")
|
||||
|
||||
img_s = sitk.GetImageFromArray(segmentations[3].squeeze())
|
||||
sitk.WriteImage(img_s, f"{IMAGE_DIR}/segmentations_001.nii.gz")
|
||||
|
||||
|
||||
|
||||
|
||||
# create plot
|
||||
# json_path = './../scripts/metrics.json'
|
||||
# f = open(json_path)
|
||||
# data = json.load(f)
|
||||
# x = data['fpr']
|
||||
# y = data['tpr']
|
||||
# auroc = data['auroc']
|
||||
# plt.plot(x,y)
|
|
@ -0,0 +1,68 @@
|
|||
import sys
|
||||
sys.path.append('./../code')
|
||||
from utils_quintin import *
|
||||
import matplotlib.pyplot as plt
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Visualise froc results')
|
||||
parser.add_argument('-saveas',
|
||||
help='')
|
||||
parser.add_argument('-comparison',
|
||||
help='')
|
||||
parser.add_argument('--experiment', '-s',
|
||||
metavar='[series_name]', required=True, nargs='+',
|
||||
help='List of series to include, must correspond with' +
|
||||
"path files in ./data/")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.comparison:
|
||||
colors = ['r','r','b','b','g','g']
|
||||
plot_type = ['-','--','-','--','-','--']
|
||||
else:
|
||||
colors = ['r','b','g','k']
|
||||
plot_type = ['-','-','-','-']
|
||||
|
||||
experiments = args.experiment
|
||||
print(experiments)
|
||||
experiment_path = []
|
||||
experiment_metrics = {}
|
||||
auroc = []
|
||||
for idx in range(len(args.experiment)):
|
||||
experiment_path = f'./../train_output/{experiments[idx]}/froc_metrics.yml'
|
||||
experiment_metrics = read_yaml_to_dict(experiment_path)
|
||||
auroc.append(round(experiment_metrics['auroc'],3))
|
||||
|
||||
plt.figure(1)
|
||||
plt.plot(experiment_metrics["FP_per_case"], experiment_metrics["sensitivity"],color=colors[idx],linestyle=plot_type[idx])
|
||||
|
||||
plt.figure(2)
|
||||
plt.plot(experiment_metrics["fpr"], experiment_metrics["tpr"],color=colors[idx],linestyle=plot_type[idx])
|
||||
|
||||
print(auroc)
|
||||
experiments = [exp.replace('train_10h_', '') for exp in experiments]
|
||||
experiments = [exp.replace('train_n0.001_', '') for exp in experiments]
|
||||
experiments = [exp.replace('_', ' ') for exp in experiments]
|
||||
# experiments = ['10% noise','1% noise','0.1% noise','0.05% noise']
|
||||
|
||||
plt.figure(1)
|
||||
plt.title('fROC curve')
|
||||
plt.xlabel('False positive per case')
|
||||
plt.ylabel('Sensitivity')
|
||||
plt.legend(experiments,loc='lower right')
|
||||
plt.xlim([0,3])
|
||||
plt.ylim([0,1])
|
||||
plt.yticks([0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1])
|
||||
plt.grid()
|
||||
plt.savefig(f"./../train_output/fROC_{args.saveas}.png", dpi=300)
|
||||
|
||||
concat_func = lambda x,y: x + " (" + str(y) + ")"
|
||||
experiments_auroc = list(map(concat_func,experiments,auroc)) # list the map function
|
||||
|
||||
plt.figure(2)
|
||||
plt.title('ROC curve')
|
||||
plt.legend(experiments_auroc,loc='lower right')
|
||||
plt.xlabel('False positive rate')
|
||||
plt.ylabel('True positive rate')
|
||||
plt.grid()
|
||||
plt.savefig(f"./../train_output/ROC_{args.saveas}.png", dpi=300)
|
|
@ -0,0 +1,108 @@
|
|||
import sys
|
||||
from os import path
|
||||
import SimpleITK as sitk
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras.models import load_model
|
||||
from focal_loss import BinaryFocalLoss
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from sfransen.Saliency.base import *
|
||||
from sfransen.Saliency.integrated_gradients import *
|
||||
# from tensorflow.keras.vis.visualization import visualize_saliency
|
||||
|
||||
sys.path.append('./../code')
|
||||
from utils_quintin import *
|
||||
|
||||
sys.path.append('./../code/DWI_exp')
|
||||
# from preprocessing_function import preprocess
|
||||
from sfransen.DWI_exp import preprocess
|
||||
print("done step 1")
|
||||
from sfransen.DWI_exp.helpers import *
|
||||
# from helpers import *
|
||||
from callbacks import dice_coef
|
||||
|
||||
sys.path.append('./../code/FROC')
|
||||
from blob_preprocess import *
|
||||
from cal_froc_from_np import *
|
||||
|
||||
quit()
|
||||
# train_10h_t2_b50_b400_b800_b1400_adc
|
||||
SERIES = ['t2','b50','b400','b800','b1400','adc']
|
||||
MODEL_PATH = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/models/train_10h_t2_b50_b400_b800_b1400_adc.h5'
|
||||
YAML_DIR = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc'
|
||||
################ constants ############
|
||||
DATA_DIR = "./../data/Nijmegen paths/"
|
||||
TARGET_SPACING = (0.5, 0.5, 3)
|
||||
INPUT_SHAPE = (192, 192, 24, len(SERIES))
|
||||
IMAGE_SHAPE = INPUT_SHAPE[:3]
|
||||
|
||||
# import val_indx
|
||||
# DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
|
||||
# TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
|
||||
|
||||
experiment_path = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/froc_metrics.yml'
|
||||
experiment_metrics = read_yaml_to_dict(experiment_path)
|
||||
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
|
||||
TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
|
||||
|
||||
top_10_idx = np.argsort(experiment_metrics['roc_pred'])[-10:]
|
||||
TEST_INDEX = [TEST_INDEX[i] for i in top_10_idx]
|
||||
|
||||
########## load images ##############
|
||||
images, image_paths = {s: [] for s in SERIES}, {}
|
||||
segmentations = []
|
||||
print_(f"> Loading images into RAM...")
|
||||
|
||||
for s in SERIES:
|
||||
with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f:
|
||||
image_paths[s] = [l.strip() for l in f.readlines()]
|
||||
with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f:
|
||||
seg_paths = [l.strip() for l in f.readlines()]
|
||||
num_images = len(seg_paths)
|
||||
|
||||
# Read and preprocess each of the paths for each SERIES, and the segmentations.
|
||||
for img_idx in TEST_INDEX[:5]: #for less images
|
||||
img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)
|
||||
for s in SERIES}
|
||||
seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
|
||||
img_n, seg_n = preprocess(img_s, seg_s,
|
||||
shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
|
||||
for seq in img_n:
|
||||
images[seq].append(img_n[seq])
|
||||
segmentations.append(seg_n)
|
||||
|
||||
images_list = [images[s] for s in images.keys()]
|
||||
images_list = np.transpose(images_list, (1, 2, 3, 4, 0))
|
||||
|
||||
########### load module ##################
|
||||
dependencies = {
|
||||
'dice_coef': dice_coef
|
||||
}
|
||||
reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies)
|
||||
|
||||
# reconstructed_model.layers[-1].activation = tf.keras.activations.linear
|
||||
|
||||
print('START prediction')
|
||||
|
||||
ig = IntegratedGradients(reconstructed_model)
|
||||
saliency_map = []
|
||||
for img_idx in range(len(images_list)):
|
||||
# input_img = np.resize(images_list[img_idx],(1,48,48,8,8))
|
||||
input_img = np.resize(images_list[img_idx],(1,192,192,24,len(SERIES)))
|
||||
saliency_map.append(ig.get_mask(input_img).numpy())
|
||||
print("size saliency map is:",np.shape(saliency_map))
|
||||
|
||||
np.save('saliency',saliency_map)
|
||||
|
||||
# Christian Roest, [11-3-2022 15:30]
|
||||
# input_img heeft dimensies (1, 48, 48, 8, 8)
|
||||
|
||||
# reconstructed_model.summary(line_length=120)
|
||||
|
||||
# make predictions on all val_indx
|
||||
print('START saliency')
|
||||
# predictions_blur = reconstructed_model.predict(images_list, batch_size=1)
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.cm as cm
|
||||
|
||||
heatmap = np.load('saliency.npy')
|
||||
print(np.shape(heatmap))
|
||||
heatmap = np.squeeze(heatmap)
|
||||
print(np.shape(heatmap))
|
||||
|
||||
|
||||
### take average over 5 #########
|
||||
heatmap = np.mean(abs(heatmap),axis=0)
|
||||
print(np.shape(heatmap))
|
||||
|
||||
SERIES = ['t2','b50','b400','b800','b1400','adc']
|
||||
fig, axes = plt.subplots(1,6)
|
||||
max_value = np.amax(heatmap)
|
||||
pri
|
||||
min_value = np.amin(heatmap)
|
||||
# vmin vmax van hele heatmap voor scaling in imshow
|
||||
# cmap naar grey
|
||||
|
||||
im = axes[0].imshow(np.squeeze(heatmap[:,:,12,0]))
|
||||
axes[1].imshow(np.squeeze(heatmap[:,:,12,1]), vmin=min_value, vmax=max_value)
|
||||
axes[2].imshow(np.squeeze(heatmap[:,:,12,2]), vmin=min_value, vmax=max_value)
|
||||
axes[3].imshow(np.squeeze(heatmap[:,:,12,3]), vmin=min_value, vmax=max_value)
|
||||
axes[4].imshow(np.squeeze(heatmap[:,:,12,4]), vmin=min_value, vmax=max_value)
|
||||
axes[5].imshow(np.squeeze(heatmap[:,:,12,5]), vmin=min_value, vmax=max_value)
|
||||
|
||||
axes[0].set_title("t2")
|
||||
axes[1].set_title("b50")
|
||||
axes[2].set_title("b400")
|
||||
axes[3].set_title("b800")
|
||||
axes[4].set_title("b1400")
|
||||
axes[5].set_title("adc")
|
||||
|
||||
cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.5, orientation='horizontal')
|
||||
cbar.set_ticks([-0.1,0,0.1])
|
||||
cbar.set_ticklabels(['less importance', '0', 'important'])
|
||||
fig.suptitle('Average saliency maps over the 5 highest predictions', fontsize=16)
|
||||
plt.show()
|
||||
|
||||
quit()
|
||||
|
||||
#take one image out
|
||||
heatmap = np.squeeze(heatmap[0])
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Fixing random state for reproducibility
|
||||
np.random.seed(19680801)
|
||||
|
||||
|
||||
class IndexTracker:
|
||||
def __init__(self, ax, X):
|
||||
self.ax = ax
|
||||
ax.set_title('use scroll wheel to navigate images')
|
||||
|
||||
self.X = X
|
||||
rows, cols, self.slices = X.shape
|
||||
self.ind = self.slices//2
|
||||
|
||||
self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet')
|
||||
self.update()
|
||||
|
||||
def on_scroll(self, event):
|
||||
print("%s %s" % (event.button, event.step))
|
||||
if event.button == 'up':
|
||||
self.ind = (self.ind + 1) % self.slices
|
||||
else:
|
||||
self.ind = (self.ind - 1) % self.slices
|
||||
self.update()
|
||||
|
||||
def update(self):
|
||||
self.im.set_data(self.X[:, :, self.ind])
|
||||
self.ax.set_ylabel('slice %s' % self.ind)
|
||||
self.im.axes.figure.canvas.draw()
|
||||
|
||||
plt.figure(0)
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
tracker = IndexTracker(ax, heatmap[:,:,:,5])
|
||||
fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
|
||||
plt.show()
|
||||
|
||||
plt.figure(1)
|
||||
fig, ax = plt.subplots(1, 1)
|
||||
tracker = IndexTracker(ax, heatmap[:,:,:,3])
|
||||
fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
|
||||
plt.show()
|
|
@ -0,0 +1,59 @@
|
|||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import glob
|
||||
import argparse
|
||||
import os
|
||||
|
||||
#create parser
|
||||
def parse_input_args():
|
||||
parser = argparse.ArgumentParser(description='Parse arguments for training a Reconstruction model')
|
||||
|
||||
parser.add_argument('train_out_dir',
|
||||
type=str,
|
||||
help='Directory name in train_output dir of the desired experiment folder. There should be a .csv file in this directory with train statistics.')
|
||||
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
args = parse_input_args()
|
||||
|
||||
print(f"Plotting {args}")
|
||||
|
||||
# find csv file
|
||||
# csv = glob.glob(f"train_output/{args.train_out_dir}/*.csv")[0]
|
||||
folder_input = args.train_out_dir
|
||||
|
||||
# load csv file
|
||||
df = pd.read_csv(f'{folder_input}')
|
||||
|
||||
# read csv file
|
||||
for metric in df:
|
||||
# if not metric == 'epoch':
|
||||
if metric == 'loss' or metric == 'val_loss':
|
||||
plt.plot(df['epoch'], df[metric], label=metric)
|
||||
plt.ylim(ymin=0,ymax=0.01)
|
||||
|
||||
|
||||
|
||||
folder, csvfile = os.path.split(args.train_out_dir)
|
||||
root, experiment = os.path.split(os.path.split(folder)[0])
|
||||
|
||||
plt.title(experiment)
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Loss')
|
||||
plt.grid()
|
||||
plt.legend()
|
||||
plt.savefig(f"{folder}/{experiment}.png")
|
||||
plt.clf()
|
||||
plt.close()
|
||||
print(folder+".png")
|
||||
print(f"\nsaved figure to {folder}")
|
||||
|
||||
# van yaml inladen 'loss'
|
||||
# vanuit utils > dict to yaml
|
||||
# csv viewer extension
|
||||
# course .venv/bin/activate
|
||||
# loop over alles en if metric then 1-metric
|
||||
# kleur codering color=[]
|
||||
|
||||
|
Loading…
Reference in New Issue