"""Evaluate a trained U-Net on the validation split with FROC metrics.

Loads the requested MRI series and the segmentations for the ``val_set0``
indices, predicts with the model stored under
``./../train_output/train_n0.001_<series>``, computes FROC metrics (written
as YAML to that experiment directory) and saves a few example volumes.
"""
import argparse
import json
import multiprocessing
import sys
from functools import partial
from os import path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
import tensorflow as tf
from focal_loss import BinaryFocalLoss
from tensorflow.keras.models import load_model

# Project helpers live in sibling directories; these sys.path entries are
# relative to the current working directory, not to this file.
sys.path.append('./../code')
from utils_quintin import *
sys.path.append('./../code/DWI_exp')
from helpers import *
from preprocessing_function import preprocess
from callbacks import dice_coef
sys.path.append('./../code/FROC')
from blob_preprocess import *
from cal_froc_from_np import *

parser = argparse.ArgumentParser(
    description='Evaluate a trained U-Net model for segmentation/detection '
                'tasks on the validation split (FROC).')
parser.add_argument('--series', '-s',
                    metavar='[series_name]', required=True, nargs='+',
                    help='List of series to include, must correspond with '
                         'path files in ./../data/Nijmegen paths/')
args = parser.parse_args()

######## parsed inputs #############
SERIES = args.series  # e.g. ['b50', 'b400', 'b800']
series_ = '_'.join(args.series)

# Experiment directory holding the trained model; FROC metrics are dumped here.
YAML_DIR = f'./../train_output/train_n0.001_{series_}'
MODEL_PATH = f'{YAML_DIR}/models/train_n0.001_{series_}.h5'
print(MODEL_PATH)

################ constants ############
DATA_DIR = "./../data/Nijmegen paths/"
TARGET_SPACING = (0.5, 0.5, 3)
INPUT_SHAPE = (192, 192, 24, len(SERIES))  # last axis: one channel per series
IMAGE_SHAPE = INPUT_SHAPE[:3]
N_CPUS = 12  # worker processes used to load volumes in parallel

# Case indices to evaluate: the 'val_set0' entry of the split file.
DATA_SPLIT_INDEX = read_yaml_to_dict(path.join(DATA_DIR, 'train_val_test_idxs.yml'))
TEST_INDEX = DATA_SPLIT_INDEX['val_set0']

########## load images ##############
print_("> Loading images into RAM...")


def _read_path_list(filename):
    """Return the stripped lines of DATA_DIR/<filename> (one path per line)."""
    with open(path.join(DATA_DIR, filename), 'r') as f:
        return [line.strip() for line in f.readlines()]


image_paths = {s: _read_path_list(f"{s}.txt") for s in SERIES}
seg_paths = _read_path_list("seg.txt")
num_images = len(seg_paths)

# Series and segmentation lists are paired purely by line index (TEST_INDEX is
# applied to each), so differing lengths mean the lists are out of sync.
for s in SERIES:
    if len(image_paths[s]) != num_images:
        raise ValueError(f"{s}.txt lists {len(image_paths[s])} paths, "
                         f"but seg.txt lists {num_images}.")


def load_images(
        image_paths: str,
        seq: str,
        target_shape: List[int],
        target_space: List[float]):
    """Read one volume, resample and center-crop it, and return it as numpy.

    Args:
        image_paths: path of a single image file (one path per call, despite
            the plural name).
        seq: 'seg' for a segmentation, which is binarised with a lower
            threshold of 1.0. Any other value is treated as an MRI series
            and passed through SimpleITK's NormalizeImageFilter.
        target_shape: shape passed to ``resample`` (as ``min_shape``) and to
            ``center_crop``.
        target_space: voxel spacing passed to ``resample`` (``new_spacing``).

    Returns:
        The voxel array transposed (``.T``) from SimpleITK's (z, y, x) array
        order, i.e. in (x, y, z) order.
    """
    volume = sitk.ReadImage(image_paths, sitk.sitkFloat32)
    # NOTE(review): nearest-neighbour interpolation is used for the MRI
    # series too, not only for the segmentation; confirm this matches the
    # preprocessing used at training time before changing it.
    volume = resample(volume, min_shape=target_shape,
                      method=sitk.sitkNearestNeighbor, new_spacing=target_space)
    volume = center_crop(volume, shape=target_shape)
    if seq != 'seg':
        volume = sitk.NormalizeImageFilter().Execute(volume)
    else:
        thresholder = sitk.BinaryThresholdImageFilter()
        thresholder.SetLowerThreshold(1.0)
        volume = thresholder.Execute(volume)
    return sitk.GetArrayFromImage(volume).T


# Load every series and the segmentations for the evaluation cases in
# parallel; the context manager shuts the worker processes down afterwards.
# NOTE(review): the pool is created at module level without a __main__ guard,
# which only works with the 'fork' start method.
with multiprocessing.Pool(processes=N_CPUS) as pool:
    load_image = partial(load_images, seq='images',
                         target_shape=IMAGE_SHAPE, target_space=TARGET_SPACING)
    images_2 = []  # one stacked array per series, cases along axis 0
    for s in SERIES:
        series_paths = np.asarray(image_paths[s])[TEST_INDEX]
        images_2.append(np.stack(pool.map(load_image, series_paths), axis=0))
    print(np.shape(images_2))

    load_seg = partial(load_images, seq='seg',
                       target_shape=IMAGE_SHAPE, target_space=TARGET_SPACING)
    seg_paths_index = np.asarray(seg_paths)[TEST_INDEX]
    segmentations = np.stack(pool.map(load_seg, seg_paths_index), axis=0)
# Combine the per-series arrays into the network input by stacking them
# along a new trailing channel axis (one channel per series).
images_list = np.stack(images_2, axis=-1)
print("images size ", np.shape(images_list))
print("size segmentation", np.shape(segmentations))

import os
# NOTE(review): set after `import tensorflow`; this only takes effect if CUDA
# has not been initialised yet when the model is loaded below — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

########### load model ##################
print(' >>>>>>> LOAD MODEL <<<<<<<<<')
# The saved model references the custom `dice_coef` metric by name.
dependencies = dict(dice_coef=dice_coef)
reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies)

# Predict on every evaluation case, one volume per batch.
print(' >>>>>>> START prediction <<<<<<<<<')
predictions_blur = reconstructed_model.predict(images_list, batch_size=1)

# Next stage: turn the blurred predictions into individual blobs.
print('>>>>>>>> START preprocess')


def move_dims(arr):
    """Reorder axes from the UMCG layout to Joeran's layout.

    UMCG: (batch, width, height, depth); Joeran: (batch, depth, height, width).
    Axis 3 moves to position 1 and axis 2 stays at position 2, so the result
    is (0, 3, 2, 1). Any axes after the fourth keep their place.
    """
    return np.moveaxis(arr, (3, 2), (1, 2))
# Joeran's FROC helpers expect (batch, depth, height, width) arrays.


def _drop_channel_axis(arr):
    """Return *arr* without a trailing size-1 channel axis on a 5-D batch.

    Only that axis is removed, so a batch holding a single case keeps its
    batch axis. A bare ``np.squeeze`` would drop it and break ``move_dims``.
    """
    arr = np.asarray(arr)
    if arr.ndim == 5 and arr.shape[-1] == 1:
        return arr[..., 0]
    return arr


def _zero_border(volumes, border=2):
    """Return a float64 copy of *volumes* with the outer *border* pixels of
    the last two axes (height, width) set to zero."""
    volumes = np.asarray(volumes)
    cleaned = np.zeros(volumes.shape)
    height, width = volumes.shape[-2:]
    inner = (..., slice(border, height - border), slice(border, width - border))
    cleaned[inner] = volumes[inner]
    return cleaned


def _save_volume(volume, filename):
    """Write one (depth, height, width) volume to IMAGE_DIR as *filename*."""
    img = sitk.GetImageFromArray(np.squeeze(volume))
    # NOTE(review): assumes TARGET_SPACING is in SimpleITK (x, y, z) order,
    # i.e. the grid the volumes were resampled to — confirm.
    img.SetSpacing([float(v) for v in TARGET_SPACING])
    sitk.WriteImage(img, f"{IMAGE_DIR}/{filename}")


predictions_blur = move_dims(_drop_channel_axis(predictions_blur))
segmentations = move_dims(_drop_channel_axis(segmentations))

# Remove the blur and split each prediction into individual blobs.
predictions = [preprocess_softmax(pred, threshold="dynamic")[0]
               for pred in predictions_blur]

# Remove outer edges (2 pixels on each in-plane side).
predictions = _zero_border(predictions, border=2)

# perform FROC
metrics = evaluate(y_true=segmentations, y_pred=predictions)
dump_dict_to_yaml(metrics, YAML_DIR, "froc_metrics", verbose=True)

# Save one example case next to the model and metrics it belongs to
# (the same train_n0.001_* experiment directory, not another run's folder).
IMAGE_DIR = YAML_DIR
EXAMPLE_IDX = min(3, len(predictions) - 1)  # case 3, or the last if fewer
_save_volume(predictions_blur[EXAMPLE_IDX], "predictions_blur_001.nii.gz")
_save_volume(predictions[EXAMPLE_IDX], "predictions_001.nii.gz")
_save_volume(segmentations[EXAMPLE_IDX], "segmentations_001.nii.gz")

# TODO: plot the FROC curve from the saved metrics
# (an earlier draft read fpr/tpr/auroc from ./../scripts/metrics.json).