opschonen van 2.froc.py

This commit is contained in:
Stefan 2022-03-21 12:25:15 +01:00
parent be5b392456
commit 02d5b371d6
2 changed files with 55 additions and 143 deletions

View File

@ -1,63 +1,56 @@
import sys
from os import path
import SimpleITK as sitk import SimpleITK as sitk
import tensorflow as tf import tensorflow as tf
from tensorflow.keras.models import load_model from tensorflow.keras.models import load_model
from focal_loss import BinaryFocalLoss from focal_loss import BinaryFocalLoss
import json
import matplotlib.pyplot as plt
import numpy as np import numpy as np
import multiprocessing import multiprocessing
from functools import partial from functools import partial
import os
sys.path.append('./../code') from sfransen.utils_quintin import *
from utils_quintin import * from sfransen.DWI_exp.helpers import *
from sfransen.DWI_exp.preprocessing_function import preprocess
sys.path.append('./../code/DWI_exp') from sfransen.DWI_exp.callbacks import dice_coef
from helpers import * from sfransen.FROC.blob_preprocess import *
from preprocessing_function import preprocess from sfransen.FROC.cal_froc_from_np import *
from callbacks import dice_coef from sfransen.load_images import load_images_parrallel
sys.path.append('./../code/FROC')
from blob_preprocess import *
from cal_froc_from_np import *
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Train a U-Net model for segmentation/detection tasks.' + description='Calculate the froc metrics and store in froc_metrics.yml')
'using cross-validation.') parser.add_argument('-experiment',
help='Title of experiment')
parser.add_argument('--series', '-s', parser.add_argument('--series', '-s',
metavar='[series_name]', required=True, nargs='+', metavar='[series_name]', required=True, nargs='+',
help='List of series to include, must correspond with' + help='List of series to include')
"path files in ./data/")
args = parser.parse_args() args = parser.parse_args()
######## parsed inputs ############# ######## CUDA ################
# SERIES = ['b50', 'b400', 'b800'] #can be parsed os.environ["CUDA_VISIBLE_DEVICES"] = "2"
######## constants #############
SERIES = args.series SERIES = args.series
series_ = '_'.join(args.series) series_ = '_'.join(args.series)
# Import model EXPERIMENT = args.experiment
# MODEL_PATH = f'./../train_output/train_10h_{series_}/models/train_10h_{series_}.h5'
# YAML_DIR = f'./../train_output/train_10h_{series_}' MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}/models/{EXPERIMENT}_{series_}.h5'
MODEL_PATH = f'./../train_output/train_n0.001_{series_}/models/train_n0.001_{series_}.h5' YAML_DIR = f'./../train_output/{EXPERIMENT}_{series_}'
print(MODEL_PATH)
YAML_DIR = f'./../train_output/train_n0.001_{series_}'
################ constants ############
DATA_DIR = "./../data/Nijmegen paths/" DATA_DIR = "./../data/Nijmegen paths/"
TARGET_SPACING = (0.5, 0.5, 3) TARGET_SPACING = (0.5, 0.5, 3)
INPUT_SHAPE = (192, 192, 24, len(SERIES)) INPUT_SHAPE = (192, 192, 24, len(SERIES))
IMAGE_SHAPE = INPUT_SHAPE[:3] IMAGE_SHAPE = INPUT_SHAPE[:3]
# import val_indx
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = DATA_SPLIT_INDEX['val_set0'] TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
N_CPUS = 12
########## load images ############## ########## load images in parrallel ##############
images, image_paths = {s: [] for s in SERIES}, {}
segmentations = []
print_(f"> Loading images into RAM...") print_(f"> Loading images into RAM...")
# read paths from txt
image_paths = {}
for s in SERIES: for s in SERIES:
with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f: with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f:
image_paths[s] = [l.strip() for l in f.readlines()] image_paths[s] = [l.strip() for l in f.readlines()]
@ -65,90 +58,33 @@ with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f:
seg_paths = [l.strip() for l in f.readlines()] seg_paths = [l.strip() for l in f.readlines()]
num_images = len(seg_paths) num_images = len(seg_paths)
# Read and preprocess each of the paths for each SERIES, and the segmentations. # create pool of workers
from typing import List
def load_images(
image_paths: str,
seq: str,
target_shape: List[int],
target_space = List[float]):
img_s = sitk.ReadImage(image_paths, sitk.sitkFloat32)
#resample
mri_tra_s = resample(img_s,
min_shape=target_shape,
method=sitk.sitkNearestNeighbor,
new_spacing=target_space)
#center crop
mri_tra_s = center_crop(mri_tra_s, shape=target_shape)
#normalize
if seq != 'seg':
filter = sitk.NormalizeImageFilter()
mri_tra_s = filter.Execute(mri_tra_s)
else:
filter = sitk.BinaryThresholdImageFilter()
filter.SetLowerThreshold(1.0)
mri_tra_s = filter.Execute(mri_tra_s)
return sitk.GetArrayFromImage(mri_tra_s).T
N_CPUS = 12
pool = multiprocessing.Pool(processes=N_CPUS) pool = multiprocessing.Pool(processes=N_CPUS)
partial_f = partial(load_images, partial_images = partial(load_images_parrallel,
seq = 'images', seq = 'images',
target_shape=IMAGE_SHAPE, target_shape=IMAGE_SHAPE,
target_space = TARGET_SPACING) target_space = TARGET_SPACING)
partial_seg = partial(load_images_parrallel,
images_2 = []
for s in SERIES:
image_paths_seq = image_paths[s]
image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX]
data_list = pool.map(partial_f,image_paths_index)
data = np.stack(data_list, axis=0)
images_2.append(data)
# print(s)
# print(np.shape(data))
print(np.shape(images_2))
partial_f = partial(load_images,
seq = 'seg', seq = 'seg',
target_shape=IMAGE_SHAPE, target_shape=IMAGE_SHAPE,
target_space = TARGET_SPACING) target_space = TARGET_SPACING)
#load images
images = []
for s in SERIES:
image_paths_seq = image_paths[s]
image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX]
data_list = pool.map(partial_images,image_paths_index)
data = np.stack(data_list, axis=0)
images.append(data)
images_list = np.transpose(images, (1, 2, 3, 4, 0))
#load segmentations
seg_paths_index = np.asarray(seg_paths)[TEST_INDEX] seg_paths_index = np.asarray(seg_paths)[TEST_INDEX]
data_list = pool.map(partial_f,seg_paths_index) data_list = pool.map(partial_seg,seg_paths_index)
segmentations = np.stack(data_list, axis=0) segmentations = np.stack(data_list, axis=0)
# print("segmentations pool",np.shape(segmentations_2))
# for img_idx in TEST_INDEX: #for less images
# img_s = {s: sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32)
# for s in SERIES}
# seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
# img_n, seg_n = preprocess(img_s, seg_s,
# shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
# for seq in img_n:
# images[seq].append(img_n[seq])
# segmentations.append(seg_n)
# print("segmentations old",np.shape(segmentations))
# # from dict to list
# # images_list = [img nmbr, [INPUT_SHAPE]]
# images_list = [images[s] for s in images.keys()]
# images_list = np.transpose(images_list, (1, 2, 3, 4, 0))
images_list = np.transpose(images_2, (1, 2, 3, 4, 0))
print("images size ",np.shape(images_list))
print("size segmentation",np.shape(segmentations))
# print("images size pool",np.shape(images_list_2))
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
########### load module ################## ########### load module ##################
print(' >>>>>>> LOAD MODEL <<<<<<<<<') print(' >>>>>>> LOAD MODEL <<<<<<<<<')
@ -158,21 +94,18 @@ dependencies = {
reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies) reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies)
# reconstructed_model.summary(line_length=120) # reconstructed_model.summary(line_length=120)
# make predictions on all val_indx # make predictions on all TEST_INDEX
print(' >>>>>>> START prediction <<<<<<<<<') print(' >>>>>>> START prediction <<<<<<<<<')
predictions_blur = reconstructed_model.predict(images_list, batch_size=1) predictions_blur = reconstructed_model.predict(images_list, batch_size=1)
# print("The shape of the predictions list is: ",np.shape(predictions_blur)) ############# preprocess #################
# print(type(predictions))
# np.save('predictions',predictions)
# preprocess predictions by removing the blur and making individual blobs # preprocess predictions by removing the blur and making individual blobs
print('>>>>>>>> START preprocess') print('>>>>>>>> START preprocess')
def move_dims(arr): def move_dims(arr):
# UMCG numpy dimensions convention: dims = (batch, width, heigth, depth) # UMCG numpy dimensions convention: dims = (batch, width, heigth, depth)
# Joeran numpy dimensions convention: dims = (batch, depth, heigth, width) # Joeran numpy dimensions convention: dims = (batch, depth, heigth, width)
arr = np.moveaxis(arr, 3, 1) # Joeran has his numpy arrays ordered differently. arr = np.moveaxis(arr, 3, 1)
arr = np.moveaxis(arr, 3, 2) arr = np.moveaxis(arr, 3, 2)
return arr return arr
@ -192,8 +125,8 @@ metrics = evaluate(y_true=segmentations, y_pred=predictions)
dump_dict_to_yaml(metrics, YAML_DIR, "froc_metrics", verbose=True) dump_dict_to_yaml(metrics, YAML_DIR, "froc_metrics", verbose=True)
############## save image as example #################
# save one image # save image nmr 3
IMAGE_DIR = f'./../train_output/train_10h_{series_}' IMAGE_DIR = f'./../train_output/train_10h_{series_}'
img_s = sitk.GetImageFromArray(predictions_blur[3].squeeze()) img_s = sitk.GetImageFromArray(predictions_blur[3].squeeze())
sitk.WriteImage(img_s, f"{IMAGE_DIR}/predictions_blur_001.nii.gz") sitk.WriteImage(img_s, f"{IMAGE_DIR}/predictions_blur_001.nii.gz")
@ -203,15 +136,3 @@ sitk.WriteImage(img_s, f"{IMAGE_DIR}/predictions_001.nii.gz")
img_s = sitk.GetImageFromArray(segmentations[3].squeeze()) img_s = sitk.GetImageFromArray(segmentations[3].squeeze())
sitk.WriteImage(img_s, f"{IMAGE_DIR}/segmentations_001.nii.gz") sitk.WriteImage(img_s, f"{IMAGE_DIR}/segmentations_001.nii.gz")
# create plot
# json_path = './../scripts/metrics.json'
# f = open(json_path)
# data = json.load(f)
# x = data['fpr']
# y = data['tpr']
# auroc = data['auroc']
# plt.plot(x,y)

View File

@ -11,24 +11,16 @@ import numpy as np
from sfransen.Saliency.base import * from sfransen.Saliency.base import *
from sfransen.Saliency.integrated_gradients import * from sfransen.Saliency.integrated_gradients import *
# from tensorflow.keras.vis.visualization import visualize_saliency from sfransen.utils_quintin import *
sys.path.append('./../code')
from utils_quintin import *
sys.path.append('./../code/DWI_exp')
# from preprocessing_function import preprocess
from sfransen.DWI_exp import preprocess from sfransen.DWI_exp import preprocess
print("done step 1")
from sfransen.DWI_exp.helpers import * from sfransen.DWI_exp.helpers import *
# from helpers import * from sfransen.DWI_exp.callbacks import dice_coef
from callbacks import dice_coef from sfransen.FROC.blob_preprocess import *
from sfransen.FROC.cal_froc_from_np import *
sys.path.append('./../code/FROC')
from blob_preprocess import *
from cal_froc_from_np import *
quit()
# train_10h_t2_b50_b400_b800_b1400_adc # train_10h_t2_b50_b400_b800_b1400_adc
SERIES = ['t2','b50','b400','b800','b1400','adc'] SERIES = ['t2','b50','b400','b800','b1400','adc']
MODEL_PATH = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/models/train_10h_t2_b50_b400_b800_b1400_adc.h5' MODEL_PATH = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/models/train_10h_t2_b50_b400_b800_b1400_adc.h5'
@ -46,7 +38,7 @@ IMAGE_SHAPE = INPUT_SHAPE[:3]
experiment_path = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/froc_metrics.yml' experiment_path = f'./../train_output/train_10h_t2_b50_b400_b800_b1400_adc/froc_metrics.yml'
experiment_metrics = read_yaml_to_dict(experiment_path) experiment_metrics = read_yaml_to_dict(experiment_path)
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = DATA_SPLIT_INDEX['val_set0'] TEST_INDEX = DATA_SPLIT_INDEX['test_set0']
top_10_idx = np.argsort(experiment_metrics['roc_pred'])[-10:] top_10_idx = np.argsort(experiment_metrics['roc_pred'])[-10:]
TEST_INDEX = [TEST_INDEX[i] for i in top_10_idx] TEST_INDEX = [TEST_INDEX[i] for i in top_10_idx]
@ -90,7 +82,6 @@ print('START prediction')
ig = IntegratedGradients(reconstructed_model) ig = IntegratedGradients(reconstructed_model)
saliency_map = [] saliency_map = []
for img_idx in range(len(images_list)): for img_idx in range(len(images_list)):
# input_img = np.resize(images_list[img_idx],(1,48,48,8,8))
input_img = np.resize(images_list[img_idx],(1,192,192,24,len(SERIES))) input_img = np.resize(images_list[img_idx],(1,192,192,24,len(SERIES)))
saliency_map.append(ig.get_mask(input_img).numpy()) saliency_map.append(ig.get_mask(input_img).numpy())
print("size saliency map is:",np.shape(saliency_map)) print("size saliency map is:",np.shape(saliency_map))