diff --git a/scripts/1.U-net_chris.py b/scripts/1.U-net_chris.py index a7df8b7..0243dbe 100755 --- a/scripts/1.U-net_chris.py +++ b/scripts/1.U-net_chris.py @@ -61,14 +61,14 @@ OUTPUT_SHAPE = (192, 192, 24, 1) # One output channel (segmentation) # Hyperparameters FOCAL_LOSS_GAMMA = 2 INITIAL_LEARNING_RATE = 1e-4 -MAX_EPOCHS = 600 +MAX_EPOCHS = 1500 EARLY_STOPPING = 50 # increase batch size BATCH_SIZE = 12 MODEL_SELECTION_METRIC = 'val_loss' MODEL_SELECTION_DIRECTION = "min" # Change to 'max' if higher value is better EARLY_STOPPING_METRIC = 'val_loss' -EARLY_STOPPING_DIRECTION = "min" # Change to 'max' if higher value is better +EARLY_STOPPING_DIRECTION = 'min' # Training configuration # add metric ROC_AUC @@ -170,19 +170,19 @@ callbacks = [ monitor=EARLY_STOPPING_METRIC, mode=EARLY_STOPPING_DIRECTION, patience=EARLY_STOPPING, - verbose=1), + verbose=2), ModelCheckpoint( filepath=path.join(PROJECT_DIR, "models", JOB_NAME + ".h5"), monitor=MODEL_SELECTION_METRIC, mode=MODEL_SELECTION_DIRECTION, - verbose=1, + verbose=2, + save_best_only=True), + ModelCheckpoint( + filepath=path.join(PROJECT_DIR, "models", JOB_NAME + "_dice" + ".h5"), + monitor='val_dice_coef', + mode='max', + verbose=2, save_best_only=True), - # ModelCheckpoint( - # filepath=path.join(PROJECT_DIR, "models_dice", JOB_NAME + ".h5"), - # monitor='val_dice_coef', - # mode='max', - # verbose=0, - # save_best_only=True), CSVLogger( filename=path.join(PROJECT_DIR, "logs", f"{JOB_NAME}.csv")), IntermediateImages( @@ -200,5 +200,8 @@ detection_model.fit(train_generator, steps_per_epoch = len(train_idxs) // BATCH_SIZE, epochs = MAX_EPOCHS, callbacks = callbacks, - verbose = 1, + verbose = 2 + + + , ) diff --git a/scripts/4.frocs.py b/scripts/4.frocs.py index 662d44d..c1053db 100755 --- a/scripts/4.frocs.py +++ b/scripts/4.frocs.py @@ -11,7 +11,7 @@ from sfransen.utils_quintin import * from sfransen.DWI_exp.helpers import * from sfransen.DWI_exp.preprocessing_function import preprocess from sfransen.DWI_exp.callbacks import dice_coef -from sfransen.FROC.blob_preprocess import * +#from sfransen.FROC.blob_preprocess import * from sfransen.FROC.cal_froc_from_np import * from sfransen.load_images import load_images_parrallel @@ -25,6 +25,9 @@ parser.add_argument('--series', '-s', help='List of series to include') args = parser.parse_args() +# if __name__ = '__main__': +# bovenstaande nodig om fork probleem op te lossen (windows cs linux) + ######## CUDA ################ os.environ["CUDA_VISIBLE_DEVICES"] = "2" @@ -33,7 +36,7 @@ SERIES = args.series series_ = '_'.join(args.series) EXPERIMENT = args.experiment -MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}/models/{EXPERIMENT}_{series_}.h5' +MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}/models/{EXPERIMENT}_{series_}_dice.h5' YAML_DIR = f'./../train_output/{EXPERIMENT}_{series_}' IMAGE_DIR = f'./../train_output/{EXPERIMENT}_{series_}' diff --git a/scripts/5.Visualize_frocs.py b/scripts/5.Visualize_frocs.py index 959d654..a48b2c2 100755 --- a/scripts/5.Visualize_frocs.py +++ b/scripts/5.Visualize_frocs.py @@ -1,3 +1,4 @@ +from pickle import TRUE from sfransen.utils_quintin import * import matplotlib.pyplot as plt import argparse @@ -15,11 +16,11 @@ parser.add_argument('--experiment', '-s', args = parser.parse_args() if args.comparison: - colors = ['r','r','b','b','g','g'] - plot_type = ['-','--','-','--','-','--'] + colors = ['r','r','b','b','g','g','y','y'] + plot_type = ['-','--','-','--','-','--','-','--'] else: - colors = ['r','b','g','k'] - plot_type = ['-','-','-','-'] + colors = ['r','b','g','k','y','c'] + plot_type = ['-','-','-','-','-','-'] experiments = args.experiment print(experiments) @@ -48,10 +49,10 @@ plt.title('fROC curve') plt.xlabel('False positive per case') plt.ylabel('Sensitivity') plt.legend(experiments,loc='lower right') -plt.xlim([0,3]) +# plt.xlim([0,50]) +plt.grid() plt.ylim([0,1]) plt.yticks([0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1]) -plt.grid() plt.savefig(f"./../train_output/fROC_{args.saveas}.png", dpi=300) concat_func = lambda x,y: x + " (" + str(y) + ")" diff --git a/scripts/6.saliency_map.py b/scripts/6.saliency_map.py index 1ded848..40265be 100755 --- a/scripts/6.saliency_map.py +++ b/scripts/6.saliency_map.py @@ -110,4 +110,5 @@ for img_idx in range(len(images_list)): print("size saliency map",np.shape(saliency_map)) np.save(f'{YAML_DIR}/saliency',saliency_map) np.save(f'{YAML_DIR}/images_list',images_list) -np.save(f'{YAML_DIR}/segmentations',segmentations) \ No newline at end of file +np.save(f'{YAML_DIR}/segmentations',segmentations) + diff --git a/scripts/7.Visualize_saliency.py b/scripts/7.Visualize_saliency.py index b19a10b..d68214c 100755 --- a/scripts/7.Visualize_saliency.py +++ b/scripts/7.Visualize_saliency.py @@ -1,6 +1,7 @@ import argparse import numpy as np import matplotlib.pyplot as plt +import SimpleITK as sitk parser = argparse.ArgumentParser( description='Calculate the froc metrics and store in froc_metrics.yml') @@ -53,5 +54,4 @@ cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.5, orientation='horiz cbar.set_ticks([min_value,max_value]) cbar.set_ticklabels(['less important', 'important']) fig.suptitle('Saliency map', fontsize=16) -plt.savefig(f'./../train_output/{EXPERIMENT}_{series_}/saliency_map.png', dpi=300) - +plt.savefig(f'./../train_output/{EXPERIMENT}_{series_}/saliency_map.png', dpi=300) \ No newline at end of file diff --git a/scripts/9.save_image.py b/scripts/9.save_image.py new file mode 100755 index 0000000..8567f36 --- /dev/null +++ b/scripts/9.save_image.py @@ -0,0 +1,136 @@ +import argparse +from os import path +import SimpleITK as sitk +import tensorflow as tf +from tensorflow.keras.models import load_model +import numpy as np + +from sfransen.utils_quintin import * +from sfransen.DWI_exp import preprocess +from sfransen.DWI_exp.helpers import * +from sfransen.DWI_exp.callbacks import dice_coef +from sfransen.FROC.blob_preprocess import * +from sfransen.FROC.cal_froc_from_np import * +from sfransen.load_images import load_images_parrallel +from sfransen.Saliency.base import * +from sfransen.Saliency.integrated_gradients import * + +parser = argparse.ArgumentParser( + description='Calculate the froc metrics and store in froc_metrics.yml') +parser.add_argument('-experiment', + help='Title of experiment') +parser.add_argument('--series', '-s', + metavar='[series_name]', required=True, nargs='+', + help='List of series to include') +args = parser.parse_args() + +######## CUDA ################ +os.environ["CUDA_VISIBLE_DEVICES"] = "2" + +######## constants ############# +SERIES = args.series +series_ = '_'.join(args.series) +EXPERIMENT = args.experiment + +MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}/models/{EXPERIMENT}_{series_}.h5' +DATA_DIR = "./../data/Nijmegen paths/" +TARGET_SPACING = (0.5, 0.5, 3) +INPUT_SHAPE = (192, 192, 24, len(SERIES)) +IMAGE_SHAPE = INPUT_SHAPE[:3] +TEST_INDEX_image = [371,12] +N_CPUS = 12 + +########## load images in parrallel ############## +print_(f"> Loading images into RAM...") + +# read paths from txt +image_paths = {} +for s in SERIES: + with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f: + image_paths[s] = [l.strip() for l in f.readlines()] +with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f: + seg_paths = [l.strip() for l in f.readlines()] +num_images = len(seg_paths) + +# create pool of workers +pool = multiprocessing.Pool(processes=N_CPUS) +partial_images = partial(load_images_parrallel, + seq = 'images', + target_shape=IMAGE_SHAPE, + target_space = TARGET_SPACING) +partial_seg = partial(load_images_parrallel, + seq = 'seg', + target_shape=IMAGE_SHAPE, + target_space = TARGET_SPACING) + +#load images +images = [] +for s in SERIES: + image_paths_seq = image_paths[s] + image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX_image] + data_list = pool.map(partial_images,image_paths_index) + data = np.stack(data_list, axis=0) + images.append(data) + +images_list = np.transpose(images, (1, 2, 3, 4, 0)) + +#load segmentations +seg_paths_index = np.asarray(seg_paths)[TEST_INDEX_image] +data_list = pool.map(partial_seg,seg_paths_index) +segmentations = np.stack(data_list, axis=0) + +########### load module ################## +print(' >>>>>>> LOAD MODEL <<<<<<<<<') + +dependencies = { + 'dice_coef': dice_coef +} +reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies) +# reconstructed_model.layers[-1].activation = tf.keras.activations.linear +predictions_blur = reconstructed_model.predict(images_list, batch_size=1) + +############# preprocess ################# +# preprocess predictions by removing the blur and making individual blobs +print('>>>>>>>> START preprocess') + +def move_dims(arr): + # UMCG numpy dimensions convention: dims = (batch, width, heigth, depth) + # Joeran numpy dimensions convention: dims = (batch, depth, heigth, width) + arr = np.moveaxis(arr, 3, 1) + arr = np.moveaxis(arr, 3, 2) + return arr + +# Joeran has his numpy arrays ordered differently. +print("images_list:",np.shape(images_list)) +print("predictions_blur:",np.shape(predictions_blur)) +print("segmentations:",np.shape(segmentations)) + +predictions_blur = move_dims(np.squeeze(predictions_blur)) +segmentations = move_dims(np.squeeze(segmentations)) +# images_list = move_dims(np.squeeze(images_list)) +predictions = [preprocess_softmax(pred, threshold="dynamic")[0] for pred in predictions_blur] + +print("images_list:",np.shape(images_list)) +print("predictions_blur:",np.shape(predictions_blur)) +print("segmentations:",np.shape(segmentations)) +print("predictions:",np.shape(predictions)) + +# Remove outer edges +zeros = np.zeros(np.shape(np.squeeze(predictions))) +test = np.squeeze(predictions)[:,2:-2,2:190,2:190] +zeros[:,2:-2,2:190,2:190] = test +predictions = zeros + +############## save image as example ################# +print("images_list size ",np.shape(images_list[0,:,:,:,0])) +img_s = sitk.GetImageFromArray(np.transpose(images_list[0,:,:,:,0].squeeze())) +sitk.WriteImage(img_s, f"./../train_output/{EXPERIMENT}_{series_}/t2_002.nii.gz") + +img_s = sitk.GetImageFromArray(predictions_blur[0].squeeze()) +sitk.WriteImage(img_s, f"./../train_output/{EXPERIMENT}_{series_}/predictions_blur_002.nii.gz") + +img_s = sitk.GetImageFromArray(predictions[0].squeeze()) +sitk.WriteImage(img_s, f"./../train_output/{EXPERIMENT}_{series_}/predictions_002.nii.gz") + +img_s = sitk.GetImageFromArray(segmentations[0].squeeze()) +sitk.WriteImage(img_s, f"./../train_output/{EXPERIMENT}_{series_}/segmentations_002.nii.gz") \ No newline at end of file diff --git a/scripts/calc_adc.py b/scripts/calc_adc.py new file mode 100755 index 0000000..7017274 --- /dev/null +++ b/scripts/calc_adc.py @@ -0,0 +1,132 @@ +import numpy as np +import SimpleITK as sitk +import matplotlib.pyplot as plt +######## load images ############# +# path_b50 = '/data/pca-rad/datasets/radboud_new/pat0351/2016/diffusie_cro/b-50/nifti_image.nii.gz' +# path_b400 = '/data/pca-rad/datasets/radboud_new/pat0351/2016/diffusie_cro/b-400/nifti_image.nii.gz' +# path_b800 = '/data/pca-rad/datasets/radboud_new/pat0351/2016/diffusie_cro/b-800/nifti_image.nii.gz' +# path_b1400 = '/data/pca-rad/datasets/radboud_new/pat0351/2016/diffusie_cro/b-1400/nifti_image.nii.gz' +# path_adc = '/data/pca-rad/datasets/radboud_new/pat0351/2016/dADC/nifti_image.nii.gz' + +# path_b50 = 'X:/sfransen/train_output/adc_exp/b50.nii.gz' +# path_b400 = 'X:/sfransen/train_output/adc_exp/b400.nii.gz' +# path_b800 = 'X:/sfransen/train_output/adc_exp/b800.nii.gz' +# path_b1400 = 'X:/sfransen/train_output/adc_exp/b1400.nii.gz' +# path_adc = 'X:/sfransen/train_output/adc_exp/adc.nii.gz' + +path_b50 = '/data/pca-rad/sfransen/train_output/adc_exp/b50_true.nii.gz' +path_b400 = '/data/pca-rad/sfransen/train_output/adc_exp/b400_true.nii.gz' +path_b800 = '/data/pca-rad/sfransen/train_output/adc_exp/b800_true.nii.gz' +path_b1400 = '/data/pca-rad/sfransen/train_output/adc_exp/b1400_true.nii.gz' +path_adc = '/data/pca-rad/sfransen/train_output/adc_exp/adc_calc_b50_b400_b800.nii.gz' + +b50 = sitk.ReadImage(path_b50, sitk.sitkFloat32) +b50 = sitk.GetArrayFromImage(b50) +b400 = sitk.ReadImage(path_b400, sitk.sitkFloat32) +b400 = sitk.GetArrayFromImage(b400) +b800 = sitk.ReadImage(path_b800, sitk.sitkFloat32) +b800 = sitk.GetArrayFromImage(b800) +b1400 = sitk.ReadImage(path_b1400, sitk.sitkFloat32) +b1400 = sitk.GetArrayFromImage(b1400) +adc = sitk.ReadImage(path_adc, sitk.sitkFloat32) +adc = sitk.GetArrayFromImage(adc) + +def show_img(greyscale_img): + fig = plt.figure() + plt.imshow(greyscale_img) + plt.axis('on') + path = f"iets.png" + fig.savefig(path, dpi=300, bbox_inches='tight') + +def calc_adc(b50, b400, b800): + mean_dwi = (50 + 400 + 800) / 3 + mean_si = np.divide(np.add(np.add(np.log(b50), np.log(b400)), np.log(b800)), 3) + + denominator = np.multiply((50 - mean_dwi), np.subtract(np.log(b50), mean_si)) + np.multiply((400 - mean_dwi), np.subtract(np.log(b400), mean_si)) + np.multiply((800 - mean_dwi), np.subtract(np.log(b800), mean_si)) + numerator = np.power((50 - mean_dwi), 2) + np.power((400 - mean_dwi), 2) + np.power((800 - mean_dwi), 2) + adc = np.divide(denominator, numerator) + return adc * -1000000 + +def calc_adc_1(b50,b800): + mean_dwi = (50 + 800) / 2 + mean_si = np.divide(np.add(np.log(b50), np.log(b800)), 2) + + denominator = np.multiply((50 - mean_dwi), np.subtract(np.log(b50), mean_si)) + np.multiply((800 - mean_dwi), np.subtract(np.log(b800), mean_si)) + numerator = np.power((50 - mean_dwi), 2) + np.power((800 - mean_dwi), 2) + adc = np.divide(denominator, numerator) + return adc * -1000000 + +def calc_adc_2(b50,b400): + mean_dwi = (50 + 400) / 2 + mean_si = np.divide(np.add(np.log(b50), np.log(b400)), 2) + + denominator = np.multiply((50 - mean_dwi), np.subtract(np.log(b50), mean_si)) + np.multiply((400 - mean_dwi), np.subtract(np.log(b400), mean_si)) + numerator = np.power((50 - mean_dwi), 2) + np.power((400 - mean_dwi), 2) + adc = np.divide(denominator, numerator) + return adc * -1000000 + +def calc_adc_3(b400,b800): + mean_dwi = (400 + 800) / 2 + mean_si = np.divide(np.add(np.log(b400), np.log(b800)), 2) + + denominator = np.multiply((400 - mean_dwi), np.subtract(np.log(b400), mean_si)) + np.multiply((800 - mean_dwi), np.subtract(np.log(b800), mean_si)) + numerator = np.power((400 - mean_dwi), 2) + np.power((800 - mean_dwi), 2) + adc = np.divide(denominator, numerator) + return adc * -1000000 + +def calc_high_b(b_value_high,b_value,b_image,ADC_map): + high_b = np.multiply(b_image, np.log(np.multiply(np.subtract(b_value,b_value_high), ADC_map))) + return high_b + + +high_b = sitk.GetImageFromArray(b1400) +sitk.WriteImage(high_b, f"./../train_output/adc_exp/b1400_true.nii.gz") + +high_b = calc_high_b(1400,50,b50,adc) +high_b = sitk.GetImageFromArray(high_b) +sitk.WriteImage(high_b, f"./../train_output/adc_exp/b1400_ref_b50.nii.gz") + +high_b = calc_high_b(1400,400,b400,adc) +high_b = sitk.GetImageFromArray(high_b) +sitk.WriteImage(high_b, f"./../train_output/adc_exp/b1400_ref_b400.nii.gz") + +high_b = calc_high_b(1400,800,b800,adc) +high_b = sitk.GetImageFromArray(high_b) +sitk.WriteImage(high_b, f"./../train_output/adc_exp/b1400_ref_b800.nii.gz") + +# b50 = sitk.GetImageFromArray(b50) +# sitk.WriteImage(b50, "./../train_output/adc_exp/b50_true.nii.gz") + +# b400 = sitk.GetImageFromArray(b400) +# sitk.WriteImage(b400, "./../train_output/adc_exp/b400_true.nii.gz") + +# b800 = sitk.GetImageFromArray(b800) +# sitk.WriteImage(b800, "./../train_output/adc_exp/b800_true.nii.gz") + +# b1400 = sitk.GetImageFromArray(b1400) +# sitk.WriteImage(b1400,f"./../train_output/adc_exp/b1400_true.nii.gz") + + +# adc = sitk.GetImageFromArray(adc) +# sitk.WriteImage(adc, f"adc_true.nii.gz") + +# adc = calc_adc(b50,b400,b800) +# print("calculated with 3 adc shape:",adc.shape) +# adc = sitk.GetImageFromArray(adc) +# sitk.WriteImage(adc, f"adc_calc_b50_b400_b800.nii.gz") + +# adc = calc_adc_1(b50,b800) +# print("calculated with 2 adc shape:",adc.shape) +# adc = sitk.GetImageFromArray(adc) +# sitk.WriteImage(adc, f"adc_calc_b50_b800.nii.gz") + +# adc = calc_adc_2(b50,b400) +# print("calculated with 2 adc shape:",adc.shape) +# adc = sitk.GetImageFromArray(adc) +# sitk.WriteImage(adc, f"adc_calc_b50_b400.nii.gz") + +# adc = calc_adc_3(b400,b800) +# print("calculated with 2 adc shape:",adc.shape) +# adc = sitk.GetImageFromArray(adc) +# sitk.WriteImage(adc, f"adc_calc_b400_b800.nii.gz") + diff --git a/scripts/tmp.py b/scripts/tmp.py new file mode 100755 index 0000000..0d91f0f --- /dev/null +++ b/scripts/tmp.py @@ -0,0 +1,13 @@ +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker + +x = [0,5,9,10,15] +y = [0,1,2,3,4] + +tick_spacing = 1 + +fig, ax = plt.subplots(1,1) +ax.plot(x,y) +ax.set_xticks([0,1,2,5,8,9,10,11,20]) +ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing)) +plt.show() \ No newline at end of file diff --git a/src/sfransen/DWI_exp/__init__.py b/src/sfransen/DWI_exp/__init__.py index c0f8c62..6a0390a 100755 --- a/src/sfransen/DWI_exp/__init__.py +++ b/src/sfransen/DWI_exp/__init__.py @@ -1,4 +1,3 @@ -print("de juiste init file") from .batchgenerator import * from .callbacks import * from .helpers import * diff --git a/src/sfransen/FROC/cal_froc_from_np.py b/src/sfransen/FROC/cal_froc_from_np.py index 47d28fb..11c4e6f 100755 --- a/src/sfransen/FROC/cal_froc_from_np.py +++ b/src/sfransen/FROC/cal_froc_from_np.py @@ -105,7 +105,6 @@ def preprocess_softmax_dynamic(softmax: np.ndarray, # set dynamic threshold to half the max threshold = max_prob / dynamic_threshold_factor - # extract blobs for dynamix threshold all_hard_blobs, _, _ = preprocess_softmax_static(working_softmax, threshold=threshold, min_voxels_detection=min_voxels_detection, @@ -188,14 +187,13 @@ def preprocess_softmax(softmax: np.ndarray, def evaluate( y_true: np.ndarray, y_pred: np.ndarray, - min_overlap=0.10, + min_overlap=0.02, overlap_func: str = 'DSC', case_confidence: str = 'max', multiple_lesion_candidates_selection_criteria='overlap', allow_unmatched_candidates_with_minimal_overlap=True, flat: Optional[bool] = None ) -> Dict[str, Any]: - # Make list out of numpy array so that it can be mapped in parallel with multiple CPUs. y_true_list = y_true #[y_true[mri_idx] for mri_idx in range(y_true.shape[0])] y_pred_list = y_pred #[y_pred[mri_idx] for mri_idx in range(y_pred.shape[0])] @@ -431,7 +429,6 @@ def evaluate_case( confidences, indexed_pred = parse_detection_map(y_pred) lesion_candidates_best_overlap: Dict[str, float] = {} - if y_true.any(): # for each malignant scan labeled_gt, num_gt_lesions = ndimage.label(y_true, np.ones((3, 3, 3))) @@ -464,7 +461,7 @@ def evaluate_case( 'overlap': overlap_score, }) print(lesion_candidates_for_target_gt) - + print("min benodigde overlap:", min_overlap) if len(lesion_candidates_for_target_gt) == 0: # no lesion candidate matched with GT mask. Add FN. y_list.append((1, 0., 0.)) diff --git a/src/sfransen/FROC/froc.py b/src/sfransen/FROC/froc.py index 16c8297..dcabc4c 100755 --- a/src/sfransen/FROC/froc.py +++ b/src/sfransen/FROC/froc.py @@ -30,10 +30,10 @@ try: except ImportError: pass -from image_utils import ( +from sfransen.FROC.image_utils import ( resize_image_with_crop_or_pad, read_label, read_prediction ) -from analysis_utils import ( +from sfransen.FROC.analysis_utils import ( parse_detection_map, calculate_iou, calculate_dsc ) diff --git a/src/sfransen/Saliency/__init__.py b/src/sfransen/Saliency/__init__.py new file mode 100755 index 0000000..37a879a --- /dev/null +++ b/src/sfransen/Saliency/__init__.py @@ -0,0 +1,2 @@ +from .base import * +from .integrated_gradients import * diff --git a/src/sfransen/Saliency/integrated_gradients.py b/src/sfransen/Saliency/integrated_gradients.py index f47bfe7..bc34d4b 100755 --- a/src/sfransen/Saliency/integrated_gradients.py +++ b/src/sfransen/Saliency/integrated_gradients.py @@ -1,12 +1,11 @@ import numpy as np import tensorflow as tf from tensorflow.keras.applications import densenet - -from base import SaliencyMap +from sfransen.Saliency.base import SaliencyMap class IntegratedGradients(SaliencyMap): - def get_mask(self, image, baseline=None, num_steps=2): + def get_mask(self, image, baseline=None, num_steps=4): """Computes Integrated Gradients for a predicted label. Args: @@ -28,7 +27,6 @@ class IntegratedGradients(SaliencyMap): baseline = np.zeros(img_size).astype(np.float32) else: baseline = baseline.astype(np.float32) - print(">>>> step ONE completed") img_input = image top_pred_idx = self.get_top_predicted_idx(image) @@ -37,27 +35,20 @@ class IntegratedGradients(SaliencyMap): for i in range(num_steps + 1) ] interpolated_image = np.vstack(interpolated_image).astype(np.float32) - print(">>>> step TWO completed") grads = [] for i, img in enumerate(interpolated_image): - print("number of image:",i) - print("size of image:",np.shape(img)) + print(f"interpolation step:",i," out of {num_steps}") img = tf.expand_dims(img, axis=0) grad = self.get_gradients(img) - print("size of grad is:",np.shape(grad)) grads.append(grad[0]) grads = tf.convert_to_tensor(grads, dtype=tf.float32) - print(">>>> step THREE completed") # 4. Approximate the integral using the trapezoidal rule grads = (grads[:-1] + grads[1:]) / 2.0 avg_grads = tf.reduce_mean(grads, axis=0) - # tf.reduce_mean(grads, axis=(0, 1, 2, 3)) - print(">>>> step FOUR completed") # 5. Calculate integrated gradients and return integrated_grads = (img_input - baseline) * avg_grads - print(">>>> step FIVE completed") return integrated_grads diff --git a/src/sfransen/load_images.py b/src/sfransen/load_images.py new file mode 100755 index 0000000..96b6935 --- /dev/null +++ b/src/sfransen/load_images.py @@ -0,0 +1,30 @@ +from typing import List +import SimpleITK as sitk +from sfransen.DWI_exp.helpers import * + +def load_images_parrallel( + image_paths: str, + seq: str, + target_shape: List[int], + target_space = List[float]): + + img_s = sitk.ReadImage(image_paths, sitk.sitkFloat32) + + #resample + mri_tra_s = resample(img_s, + min_shape=target_shape, + method=sitk.sitkNearestNeighbor, + new_spacing=target_space) + + #center crop + mri_tra_s = center_crop(mri_tra_s, shape=target_shape) + #normalize + if seq != 'seg': + filter = sitk.NormalizeImageFilter() + mri_tra_s = filter.Execute(mri_tra_s) + else: + filter = sitk.BinaryThresholdImageFilter() + filter.SetLowerThreshold(1.0) + mri_tra_s = filter.Execute(mri_tra_s) + + return sitk.GetArrayFromImage(mri_tra_s).T \ No newline at end of file