update
This commit is contained in:
parent
e3b84db978
commit
49b18fe7f0
@ -27,6 +27,9 @@ parser.add_argument('--series', '-s',
|
|||||||
help='List of series to include')
|
help='List of series to include')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
## info: adjust number of interpolation steps to 10 in scr/**/saliency/integrated_gradients.py
|
||||||
|
|
||||||
######## CUDA ################
|
######## CUDA ################
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
|
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
|
||||||
|
|
||||||
|
160
scripts/20.saliency_exp.py
Executable file
160
scripts/20.saliency_exp.py
Executable file
@ -0,0 +1,160 @@
|
|||||||
|
import argparse
|
||||||
|
from os import path
|
||||||
|
import SimpleITK as sitk
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
|
||||||
|
from sfransen.utils_quintin import *
|
||||||
|
from sfransen.DWI_exp import preprocess
|
||||||
|
from sfransen.DWI_exp.helpers import *
|
||||||
|
from sfransen.DWI_exp.callbacks import dice_coef
|
||||||
|
from sfransen.DWI_exp.losses import weighted_binary_cross_entropy
|
||||||
|
|
||||||
|
from sfransen.FROC.blob_preprocess import *
|
||||||
|
from sfransen.FROC.cal_froc_from_np import *
|
||||||
|
from sfransen.load_images import load_images_parrallel
|
||||||
|
from sfransen.Saliency.base import *
|
||||||
|
from sfransen.Saliency.integrated_gradients import *
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Calculate the froc metrics and store in froc_metrics.yml')
|
||||||
|
parser.add_argument('-experiment',
|
||||||
|
help='Title of experiment')
|
||||||
|
parser.add_argument('--series', '-s',
|
||||||
|
metavar='[series_name]', required=True, nargs='+',
|
||||||
|
help='List of series to include')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
def print_p(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
Shorthand for print(..., flush=True)
|
||||||
|
Useful on HPC cluster where output has buffered writes.
|
||||||
|
"""
|
||||||
|
print(*args, **kwargs, flush=True)
|
||||||
|
|
||||||
|
######## CUDA ################
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
|
||||||
|
|
||||||
|
######## constants #############
|
||||||
|
SERIES = args.series
|
||||||
|
series_ = '_'.join(args.series)
|
||||||
|
EXPERIMENT = args.experiment
|
||||||
|
|
||||||
|
DATA_DIR = "./../data/Nijmegen paths/"
|
||||||
|
TARGET_SPACING = (0.5, 0.5, 3)
|
||||||
|
INPUT_SHAPE = (192, 192, 24, len(SERIES))
|
||||||
|
IMAGE_SHAPE = INPUT_SHAPE[:3]
|
||||||
|
|
||||||
|
image_paths = {}
|
||||||
|
for s in SERIES:
|
||||||
|
with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f:
|
||||||
|
image_paths[s] = [l.strip() for l in f.readlines()]
|
||||||
|
with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f:
|
||||||
|
seg_paths = [l.strip() for l in f.readlines()]
|
||||||
|
num_images = len(seg_paths)
|
||||||
|
|
||||||
|
max_saliency_values = []
|
||||||
|
for fold in range(5):
|
||||||
|
print_p("fold:",fold)
|
||||||
|
|
||||||
|
# model path
|
||||||
|
MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}_{fold}/models/{EXPERIMENT}_{series_}_{fold}.h5'
|
||||||
|
YAML_DIR = f'./../train_output/{EXPERIMENT}_{series_}_{fold}'
|
||||||
|
IMAGE_DIR = f'./../train_output/{EXPERIMENT}_{series_}_{fold}'
|
||||||
|
|
||||||
|
# test indices
|
||||||
|
DATA_SPLIT_INDEX = read_yaml_to_dict(f'./../data/Nijmegen paths/train_val_test_idxs_{fold}.yml')
|
||||||
|
TEST_INDEX_IMGS = DATA_SPLIT_INDEX['test_set0']
|
||||||
|
|
||||||
|
for img_idx in TEST_INDEX_IMGS[:10]:
|
||||||
|
print_p("img_idx:",img_idx)
|
||||||
|
images = []
|
||||||
|
images_list = []
|
||||||
|
segmentations = []
|
||||||
|
saliency_map = []
|
||||||
|
|
||||||
|
# Read and preprocess each of the paths for each series, and the segmentations.
|
||||||
|
# print('images number',[TEST_INDEX[img_idx]])
|
||||||
|
img_s = {f'{s}': sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32) for s in SERIES}
|
||||||
|
seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
|
||||||
|
img_n, seg_n = preprocess(img_s, seg_s,
|
||||||
|
shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
|
||||||
|
for seq in img_n:
|
||||||
|
images.append(img_n[f'{seq}'])
|
||||||
|
images_list.append(images)
|
||||||
|
images = []
|
||||||
|
segmentations.append(seg_n)
|
||||||
|
|
||||||
|
images_list = np.transpose(images_list, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
|
|
||||||
|
########### load module ##################
|
||||||
|
# print(' >>>>>>> LOAD MODEL <<<<<<<<<')
|
||||||
|
|
||||||
|
dependencies = {
|
||||||
|
'dice_coef': dice_coef,
|
||||||
|
'weighted_cross_entropy_fn':weighted_binary_cross_entropy
|
||||||
|
}
|
||||||
|
reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies)
|
||||||
|
# reconstructed_model.summary(line_length=120)
|
||||||
|
|
||||||
|
# make predictions on all TEST_INDEX
|
||||||
|
# print(' >>>>>>> START prediction <<<<<<<<<')
|
||||||
|
predictions_blur = reconstructed_model.predict(images_list, batch_size=1)
|
||||||
|
|
||||||
|
|
||||||
|
############# preprocess #################
|
||||||
|
# preprocess predictions by removing the blur and making individual blobs
|
||||||
|
# print('>>>>>>>> START preprocess')
|
||||||
|
|
||||||
|
def move_dims(arr):
|
||||||
|
# UMCG numpy dimensions convention: dims = (batch, width, heigth, depth)
|
||||||
|
# Joeran numpy dimensions convention: dims = (batch, depth, heigth, width)
|
||||||
|
arr = np.moveaxis(arr, 3, 1)
|
||||||
|
arr = np.moveaxis(arr, 3, 2)
|
||||||
|
return arr
|
||||||
|
|
||||||
|
# Joeran has his numpy arrays ordered differently.
|
||||||
|
|
||||||
|
predictions_blur = move_dims(np.squeeze(predictions_blur,axis=4))
|
||||||
|
segmentations = move_dims(segmentations)
|
||||||
|
# predictions = [preprocess_softmax(pred, threshold="dynamic")[0] for pred in predictions_blur]
|
||||||
|
predictions = predictions_blur
|
||||||
|
# print("the size of predictions is:",np.shape(predictions))
|
||||||
|
# Remove outer edges
|
||||||
|
zeros = np.zeros(np.shape(predictions))
|
||||||
|
test = predictions[:,2:-2,2:190,2:190]
|
||||||
|
zeros[:,2:-2,2:190,2:190] = test
|
||||||
|
predictions = zeros
|
||||||
|
# print(np.shape(predictions))
|
||||||
|
######### Build Saliency heatmap ##############
|
||||||
|
# print(' >>>>>>> Build saliency map <<<<<<<<<')
|
||||||
|
|
||||||
|
ig = IntegratedGradients(reconstructed_model)
|
||||||
|
for img_idx in range(len(images_list)):
|
||||||
|
# input_img = np.resize(images_list[img_idx],(1,192,192,24,len(SERIES)))
|
||||||
|
saliency_map.append(ig.get_mask(images_list).numpy())
|
||||||
|
|
||||||
|
# print("size saliency map",np.shape(saliency_map))
|
||||||
|
|
||||||
|
idx_max = np.argmax(np.mean(np.mean(np.mean(np.squeeze(saliency_map),axis=0),axis=0),axis=0))
|
||||||
|
|
||||||
|
max_saliency_values.append(idx_max)
|
||||||
|
print_p("max_saliency_values:",max_saliency_values)
|
||||||
|
|
||||||
|
t2_max = sum(map(lambda x : x == 0, max_saliency_values))
|
||||||
|
dwi_max = sum(map(lambda x : x == 1, max_saliency_values))
|
||||||
|
adc_max = sum(map(lambda x : x == 2, max_saliency_values))
|
||||||
|
|
||||||
|
print_p(f"max value in t2 is: {t2_max}")
|
||||||
|
print_p(f"max value in dwi is: {dwi_max}")
|
||||||
|
print_p(f"max value in adc is: {adc_max}")
|
||||||
|
|
||||||
|
# np.save(f'{YAML_DIR}/saliency_new23',saliency_map)
|
||||||
|
# np.save(f'{YAML_DIR}/images_list_new23',images_list)
|
||||||
|
# np.save(f'{YAML_DIR}/segmentations_new23',segmentations)
|
||||||
|
# np.save(f'{YAML_DIR}/predictions_new23',predictions)
|
||||||
|
|
||||||
|
|
149
scripts/21.idx_lowest_predictions.py
Executable file
149
scripts/21.idx_lowest_predictions.py
Executable file
@ -0,0 +1,149 @@
|
|||||||
|
from inspect import _ParameterKind
|
||||||
|
import SimpleITK as sitk
|
||||||
|
import tensorflow as tf
|
||||||
|
from tensorflow.keras.models import load_model
|
||||||
|
from focal_loss import BinaryFocalLoss
|
||||||
|
import numpy as np
|
||||||
|
import multiprocessing
|
||||||
|
from functools import partial
|
||||||
|
import os
|
||||||
|
from os import path
|
||||||
|
from tqdm import tqdm
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from sfransen.utils_quintin import *
|
||||||
|
from sfransen.DWI_exp.helpers import *
|
||||||
|
from sfransen.DWI_exp.preprocessing_function import preprocess
|
||||||
|
from sfransen.DWI_exp.callbacks import dice_coef
|
||||||
|
#from sfransen.FROC.blob_preprocess import *
|
||||||
|
from sfransen.FROC.cal_froc_from_np import *
|
||||||
|
from sfransen.load_images import load_images_parrallel
|
||||||
|
from sfransen.DWI_exp.losses import weighted_binary_cross_entropy
|
||||||
|
from umcglib.froc import *
|
||||||
|
from umcglib.binarize import dynamic_threshold
|
||||||
|
|
||||||
|
|
||||||
|
def print_p(*args, **kwargs):
|
||||||
|
"""
|
||||||
|
Shorthand for print(..., flush=True)
|
||||||
|
Useful on HPC cluster where output has buffered writes.
|
||||||
|
"""
|
||||||
|
print(*args, **kwargs, flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
######## CUDA ################
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
|
||||||
|
N_CPUS = 12
|
||||||
|
|
||||||
|
|
||||||
|
DATA_DIR = "./../data/Nijmegen paths/"
|
||||||
|
TARGET_SPACING = (0.5, 0.5, 3)
|
||||||
|
INPUT_SHAPE = (192, 192, 24, 3)
|
||||||
|
IMAGE_SHAPE = INPUT_SHAPE[:3]
|
||||||
|
|
||||||
|
final_table = {}
|
||||||
|
difference = {}
|
||||||
|
for fold in range(5):
|
||||||
|
|
||||||
|
DATA_SPLIT_INDEX = read_yaml_to_dict(f'./../data/Nijmegen paths/train_val_test_idxs_{fold}.yml')
|
||||||
|
TEST_INDEX = DATA_SPLIT_INDEX['test_set0']
|
||||||
|
|
||||||
|
for img_idx in TEST_INDEX:
|
||||||
|
|
||||||
|
for model in ['b800','b400']:
|
||||||
|
|
||||||
|
image_paths = {}
|
||||||
|
predictions_added = []
|
||||||
|
segmentations_added = []
|
||||||
|
images = []
|
||||||
|
images_list = []
|
||||||
|
segmentations = []
|
||||||
|
|
||||||
|
if model is 'b800':
|
||||||
|
MODEL_PATH = f'./../train_output/calc_exp_t2_b1400calc2_adccalc2_{fold}/models/calc_exp_t2_b1400calc2_adccalc2_{fold}.h5'
|
||||||
|
# YAML_DIR = f'./../train_output/calc_exp_t2_b1400calc2_adccalc2_{fold}'
|
||||||
|
# IMAGE_DIR = f'./../train_output/calc_exp_t2_b1400calc2_adccalc2_{fold}'
|
||||||
|
SERIES = ['t2','b1400calc2','adccalc2']
|
||||||
|
if model is 'b400':
|
||||||
|
MODEL_PATH = f'./../train_output/calc_exp_t2_b1400calc3_adccalc3_{fold}/models/calc_exp_t2_b1400calc3_adccalc3_{fold}.h5'
|
||||||
|
SERIES = ['t2','b1400calc3','adccalc3']
|
||||||
|
|
||||||
|
for s in SERIES:
|
||||||
|
with open(path.join(DATA_DIR, f"{s}.txt"), 'r') as f:
|
||||||
|
image_paths[s] = [l.strip() for l in f.readlines()]
|
||||||
|
with open(path.join(DATA_DIR, f"seg.txt"), 'r') as f:
|
||||||
|
seg_paths = [l.strip() for l in f.readlines()]
|
||||||
|
num_images = len(seg_paths)
|
||||||
|
|
||||||
|
pat_id = os.path.basename(os.path.normpath(seg_paths[img_idx]))[:-7]
|
||||||
|
|
||||||
|
# print_p("pat_idx:",pat_id)
|
||||||
|
|
||||||
|
# Read and preprocess each of the paths for each series, and the segmentations.
|
||||||
|
# print('images number',[TEST_INDEX[img_idx]])
|
||||||
|
img_s = {f'{s}': sitk.ReadImage(image_paths[s][img_idx], sitk.sitkFloat32) for s in SERIES}
|
||||||
|
seg_s = sitk.ReadImage(seg_paths[img_idx], sitk.sitkFloat32)
|
||||||
|
img_n, seg_n = preprocess(img_s, seg_s,
|
||||||
|
shape=IMAGE_SHAPE, spacing=TARGET_SPACING)
|
||||||
|
for seq in img_n:
|
||||||
|
images.append(img_n[f'{seq}'])
|
||||||
|
images_list.append(images)
|
||||||
|
images = []
|
||||||
|
segmentations.append(seg_n)
|
||||||
|
|
||||||
|
images_list = np.transpose(images_list, (0, 2, 3, 4, 1))
|
||||||
|
|
||||||
|
########### load module ##################
|
||||||
|
# print(' >>>>>>> LOAD MODEL <<<<<<<<<')
|
||||||
|
|
||||||
|
dependencies = {
|
||||||
|
'dice_coef': dice_coef,
|
||||||
|
'weighted_cross_entropy_fn':weighted_binary_cross_entropy
|
||||||
|
}
|
||||||
|
reconstructed_model = load_model(MODEL_PATH, custom_objects=dependencies)
|
||||||
|
# reconstructed_model.summary(line_length=120)
|
||||||
|
|
||||||
|
# make predictions on all TEST_INDEX
|
||||||
|
# print(' >>>>>>> START prediction <<<<<<<<<')
|
||||||
|
predictions_blur = reconstructed_model.predict(images_list, batch_size=1)
|
||||||
|
|
||||||
|
############# preprocess #################
|
||||||
|
# preprocess predictions by removing the blur and making individual blobs
|
||||||
|
# print('>>>>>>>> START preprocess')
|
||||||
|
def move_dims(arr):
|
||||||
|
# UMCG numpy dimensions convention: dims = (batch, width, heigth, depth)
|
||||||
|
# Joeran numpy dimensions convention: dims = (batch, depth, heigth, width)
|
||||||
|
arr = np.moveaxis(arr, 3, 1)
|
||||||
|
arr = np.moveaxis(arr, 3, 2)
|
||||||
|
return arr
|
||||||
|
|
||||||
|
# Joeran has his numpy arrays ordered differently.
|
||||||
|
predictions_blur = move_dims(np.squeeze(predictions_blur,axis=4))
|
||||||
|
segmentations = move_dims(segmentations)
|
||||||
|
# predictions = [preprocess_softmax(pred, threshold="dynamic")[0] for pred in predictions_blur]
|
||||||
|
predictions = predictions_blur
|
||||||
|
# print("the size of predictions is:",np.shape(predictions))
|
||||||
|
# Remove outer edges
|
||||||
|
zeros = np.zeros(np.shape(predictions))
|
||||||
|
test = predictions[:,2:-2,2:190,2:190]
|
||||||
|
zeros[:,2:-2,2:190,2:190] = test
|
||||||
|
predictions = zeros
|
||||||
|
|
||||||
|
#make list of worst patient predictions
|
||||||
|
|
||||||
|
|
||||||
|
if model is 'b800':
|
||||||
|
final_table[pat_id] = [np.max(predictions)]
|
||||||
|
print_p(f'Max prediction of {pat_id} in b800 is {np.max(predictions)}')
|
||||||
|
if model is 'b400':
|
||||||
|
final_table[pat_id].append(np.max(predictions))
|
||||||
|
print_p(f'Max prediction of {pat_id} in b400 is {np.max(predictions)}')
|
||||||
|
|
||||||
|
difference[pat_id] = abs(np.diff(final_table[pat_id]))
|
||||||
|
sorted_difference = {k: v for k, v in sorted(difference.items(), key=lambda item: item[1])}
|
||||||
|
print_p(f'>>{fold}<<',sorted_difference)
|
||||||
|
|
||||||
|
|
||||||
|
sorted_differences = {k: v for k, v in sorted(difference.items(), key=lambda item: item[1])}
|
||||||
|
print_p('>>>>>>>>>>>>>>>>>>>>>>>>>>>><<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<')
|
||||||
|
print_p(sorted_differences)
|
31
scripts/test3.py
Executable file
31
scripts/test3.py
Executable file
@ -0,0 +1,31 @@
|
|||||||
|
from glob import glob
|
||||||
|
from os.path import normpath, basename
|
||||||
|
import SimpleITK as sitk
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_paths(main_dir):
|
||||||
|
all_niftis = glob(main_dir, recursive=True)
|
||||||
|
|
||||||
|
dwis_b800 = [i for i in all_niftis if ("diff" in i.lower() or "dwi" in i.lower()) and ("b-800" in i.lower() or "b800" in i.lower())]
|
||||||
|
|
||||||
|
dwis_b400 = [i for i in all_niftis if ("diff" in i.lower() or "dwi" in i.lower()) and ("b-400" in i.lower() or "b400" in i.lower())]
|
||||||
|
|
||||||
|
return dwis_b800, dwis_b400
|
||||||
|
|
||||||
|
|
||||||
|
pat_numbers = ['pat0132','pat0091','pat0352','pat0844','pat1006','pat0406','pat0128','pat0153','pat0062','pat0758','pat0932','pat0248','pat0129','pat0429','pat0181','pat0063','pat0674','pat0176','pat0366','pat0082']
|
||||||
|
load_path = '../../datasets/radboud_new/{pat_number}/2016/**/*.nii.gz'
|
||||||
|
|
||||||
|
for idx, pat_number in enumerate(pat_numbers):
|
||||||
|
dwis_b800,dwis_b400 = get_paths(f'../../datasets/radboud_new/{pat_number}/2016/**/*.nii.gz')
|
||||||
|
|
||||||
|
# load
|
||||||
|
dwi_b800 = sitk.ReadImage(dwis_b800, sitk.sitkFloat32)
|
||||||
|
dwi_b400 = sitk.ReadImage(dwis_b400, sitk.sitkFloat32)
|
||||||
|
# write
|
||||||
|
output_path_b800 = f'../temp/check_by_derya/{idx}_{pat_number}_b800.nii.gz'
|
||||||
|
output_path_b400 = f'../temp/check_by_derya/{idx}_{pat_number}_b400.nii.gz'
|
||||||
|
sitk.WriteImage(dwi_b800, output_path_b800)
|
||||||
|
sitk.WriteImage(dwi_b400, output_path_b400)
|
||||||
|
|
@ -44,9 +44,9 @@ class SaliencyMap():
|
|||||||
with tf.GradientTape() as tape:
|
with tf.GradientTape() as tape:
|
||||||
tape.watch(image)
|
tape.watch(image)
|
||||||
preds = self.model(image)
|
preds = self.model(image)
|
||||||
print("get_gradients, size of preds",np.shape(preds))
|
# print("get_gradients, size of preds",np.shape(preds))
|
||||||
top_class = preds[:]
|
top_class = preds[:]
|
||||||
print("get_gradients, size of top_class",np.shape(top_class))
|
# print("get_gradients, size of top_class",np.shape(top_class))
|
||||||
|
|
||||||
|
|
||||||
grads = tape.gradient(top_class, image)
|
grads = tape.gradient(top_class, image)
|
||||||
|
86
src/sfransen/Saliency/heatmap.py
Executable file
86
src/sfransen/Saliency/heatmap.py
Executable file
@ -0,0 +1,86 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import matplotlib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import scipy.ndimage as ndimage
|
||||||
|
|
||||||
|
class HeatMap:
|
||||||
|
def __init__(self,image,heat_map,gaussian_std=10):
|
||||||
|
#if image is numpy array
|
||||||
|
if isinstance(image,np.ndarray):
|
||||||
|
height = image.shape[0]
|
||||||
|
width = image.shape[1]
|
||||||
|
self.image = image
|
||||||
|
else:
|
||||||
|
#PIL open the image path, record the height and width
|
||||||
|
image = Image.open(image)
|
||||||
|
width, height = image.size
|
||||||
|
self.image = image
|
||||||
|
|
||||||
|
#Convert numpy heat_map values into image formate for easy upscale
|
||||||
|
#Rezie the heat_map to the size of the input image
|
||||||
|
#Apply the gausian filter for smoothing
|
||||||
|
#Convert back to numpy
|
||||||
|
heatmap_image = Image.fromarray(heat_map*255)
|
||||||
|
heatmap_image_resized = heatmap_image.resize((width,height))
|
||||||
|
heatmap_image_resized = ndimage.gaussian_filter(heatmap_image_resized,
|
||||||
|
sigma=(gaussian_std, gaussian_std),
|
||||||
|
order=0)
|
||||||
|
heatmap_image_resized = np.asarray(heatmap_image_resized)
|
||||||
|
self.heat_map = heatmap_image_resized
|
||||||
|
|
||||||
|
#Plot the figure
|
||||||
|
def plot(self,transparency=0.7,color_map='bwr',
|
||||||
|
show_axis=False, show_original=False, show_colorbar=False,width_pad=0):
|
||||||
|
|
||||||
|
#If show_original is True, then subplot first figure as orginal image
|
||||||
|
#Set x,y to let the heatmap plot in the second subfigure,
|
||||||
|
#otherwise heatmap will plot in the first sub figure
|
||||||
|
if show_original:
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
if not show_axis:
|
||||||
|
plt.axis('off')
|
||||||
|
plt.imshow(self.image,cmap='gray')
|
||||||
|
x,y=2,2
|
||||||
|
else:
|
||||||
|
x,y=1,1
|
||||||
|
|
||||||
|
#Plot the heatmap
|
||||||
|
plt.subplot(1,x,y)
|
||||||
|
if not show_axis:
|
||||||
|
plt.axis('off')
|
||||||
|
plt.imshow(self.image,cmap='gray')
|
||||||
|
plt.imshow(self.heat_map/255, alpha=transparency, cmap=color_map)
|
||||||
|
if show_colorbar:
|
||||||
|
plt.colorbar()
|
||||||
|
plt.tight_layout(w_pad=width_pad)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
###Save the figure
|
||||||
|
def save(self,filename,format='png',save_path=os.getcwd(),
|
||||||
|
transparency=0.7,color_map='bwr',width_pad = -10,
|
||||||
|
show_axis=False, show_original=False, show_colorbar=False, **kwargs):
|
||||||
|
if show_original:
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
if not show_axis:
|
||||||
|
plt.axis('off')
|
||||||
|
plt.imshow(self.image,cmap='gray')
|
||||||
|
x,y=2,2
|
||||||
|
else:
|
||||||
|
x,y=1,1
|
||||||
|
|
||||||
|
#Plot the heatmap
|
||||||
|
plt.subplot(1,x,y)
|
||||||
|
if not show_axis:
|
||||||
|
plt.axis('off')
|
||||||
|
plt.imshow(self.image,cmap='gray')
|
||||||
|
plt.imshow(self.heat_map/255, alpha=transparency, cmap=color_map, caxis = [min(nonzeros(self.image)) max(nonzeros(self.image))])
|
||||||
|
if show_colorbar:
|
||||||
|
plt.colorbar()
|
||||||
|
plt.tight_layout(w_pad=width_pad)
|
||||||
|
plt.savefig(os.path.join(save_path,filename+'.'+format),
|
||||||
|
format=format,
|
||||||
|
bbox_inches='tight',
|
||||||
|
pad_inches = 0, **kwargs)
|
||||||
|
print('{}.{} has been successfully saved to {}'.format(filename,format,save_path))
|
@ -5,7 +5,7 @@ from sfransen.Saliency.base import SaliencyMap
|
|||||||
|
|
||||||
class IntegratedGradients(SaliencyMap):
|
class IntegratedGradients(SaliencyMap):
|
||||||
|
|
||||||
def get_mask(self, image, baseline=None, num_steps=4):
|
def get_mask(self, image, baseline=None, num_steps=3):
|
||||||
"""Computes Integrated Gradients for a predicted label.
|
"""Computes Integrated Gradients for a predicted label.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -38,7 +38,7 @@ class IntegratedGradients(SaliencyMap):
|
|||||||
|
|
||||||
grads = []
|
grads = []
|
||||||
for i, img in enumerate(interpolated_image):
|
for i, img in enumerate(interpolated_image):
|
||||||
print(f"interpolation step:",i," out of {num_steps}")
|
# print(f"interpolation step:",i,f" out of {num_steps}")
|
||||||
img = tf.expand_dims(img, axis=0)
|
img = tf.expand_dims(img, axis=0)
|
||||||
grad = self.get_gradients(img)
|
grad = self.get_gradients(img)
|
||||||
grads.append(grad[0])
|
grads.append(grad[0])
|
||||||
|
Loading…
Reference in New Issue
Block a user