kleine aanpassingen voor weergave

This commit is contained in:
Stefan 2022-03-23 17:00:22 +01:00
parent 01f458d0db
commit 82c285b1b0
7 changed files with 38 additions and 35 deletions

View File

@ -6,10 +6,11 @@ from datetime import datetime
import sys import sys
# sys.path.append('./../code') # sys.path.append('./../code')
# from utils_quintin import * # from utils_quintin import *
# from sfransen.utils_quintin import * from sfransen.utils_quintin import *
# sys.path.append('./../code/DWI_exp') # sys.path.append('./../code/DWI_exp')
# from callbacks import IntermediateImages, dice_coef # from callbacks import IntermediateImages, dice_coef
# from callbacks import RocCallback # from callbacks import RocCallback
from sfransen.utils_quintin import *
from sfransen.DWI_exp import IntermediateImages, dice_coef from sfransen.DWI_exp import IntermediateImages, dice_coef
from sfransen.DWI_exp.preprocessing_function import preprocess from sfransen.DWI_exp.preprocessing_function import preprocess
import yaml import yaml

View File

@ -35,6 +35,7 @@ EXPERIMENT = args.experiment
MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}/models/{EXPERIMENT}_{series_}.h5' MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}/models/{EXPERIMENT}_{series_}.h5'
YAML_DIR = f'./../train_output/{EXPERIMENT}_{series_}' YAML_DIR = f'./../train_output/{EXPERIMENT}_{series_}'
IMAGE_DIR = f'./../train_output/{EXPERIMENT}_{series_}'
DATA_DIR = "./../data/Nijmegen paths/" DATA_DIR = "./../data/Nijmegen paths/"
TARGET_SPACING = (0.5, 0.5, 3) TARGET_SPACING = (0.5, 0.5, 3)
@ -42,10 +43,11 @@ INPUT_SHAPE = (192, 192, 24, len(SERIES))
IMAGE_SHAPE = INPUT_SHAPE[:3] IMAGE_SHAPE = INPUT_SHAPE[:3]
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = DATA_SPLIT_INDEX['test_set0'] TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
N_CPUS = 12 N_CPUS = 12
########## load images in parrallel ############## ########## load images in parrallel ##############
print_(f"> Loading images into RAM...") print_(f"> Loading images into RAM...")
@ -116,8 +118,8 @@ predictions = [preprocess_softmax(pred, threshold="dynamic")[0] for pred in pred
# Remove outer edges # Remove outer edges
zeros = np.zeros(np.shape(predictions)) zeros = np.zeros(np.shape(predictions))
test = np.squeeze(predictions)[:,:,2:190,2:190] test = np.squeeze(predictions)[:,2:-2,2:190,2:190]
zeros[:,:,2:190,2:190] = test zeros[:,2:-2,2:190,2:190] = test
predictions = zeros predictions = zeros
# perform Froc # perform Froc
@ -127,7 +129,6 @@ dump_dict_to_yaml(metrics, YAML_DIR, "froc_metrics", verbose=True)
############## save image as example ################# ############## save image as example #################
# save image nmr 3 # save image nmr 3
IMAGE_DIR = f'./../train_output/train_10h_{series_}'
img_s = sitk.GetImageFromArray(predictions_blur[3].squeeze()) img_s = sitk.GetImageFromArray(predictions_blur[3].squeeze())
sitk.WriteImage(img_s, f"{IMAGE_DIR}/predictions_blur_001.nii.gz") sitk.WriteImage(img_s, f"{IMAGE_DIR}/predictions_blur_001.nii.gz")

View File

@ -1,6 +1,4 @@
import sys from sfransen.utils_quintin import *
sys.path.append('./../code')
from utils_quintin import *
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import argparse import argparse

View File

@ -44,10 +44,10 @@ froc_metrics = read_yaml_to_dict(f'{YAML_DIR}/froc_metrics.yml')
top_10_idx = np.argsort(froc_metrics['roc_pred'])[-1 :] top_10_idx = np.argsort(froc_metrics['roc_pred'])[-1 :]
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml') DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = DATA_SPLIT_INDEX['test_set0'] TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
TEST_INDEX_top10 = [TEST_INDEX[i] for i in top_10_idx] TEST_INDEX_top10 = [TEST_INDEX[i] for i in top_10_idx]
TEST_INDEX_image = [371]
N_CPUS = 12 N_CPUS = 12
########## load images in parrallel ############## ########## load images in parrallel ##############
@ -77,7 +77,7 @@ partial_seg = partial(load_images_parrallel,
images = [] images = []
for s in SERIES: for s in SERIES:
image_paths_seq = image_paths[s] image_paths_seq = image_paths[s]
image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX_top10] image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX_image]
data_list = pool.map(partial_images,image_paths_index) data_list = pool.map(partial_images,image_paths_index)
data = np.stack(data_list, axis=0) data = np.stack(data_list, axis=0)
images.append(data) images.append(data)
@ -85,7 +85,7 @@ for s in SERIES:
images_list = np.transpose(images, (1, 2, 3, 4, 0)) images_list = np.transpose(images, (1, 2, 3, 4, 0))
#load segmentations #load segmentations
seg_paths_index = np.asarray(seg_paths)[TEST_INDEX_top10] seg_paths_index = np.asarray(seg_paths)[TEST_INDEX_image]
data_list = pool.map(partial_seg,seg_paths_index) data_list = pool.map(partial_seg,seg_paths_index)
segmentations = np.stack(data_list, axis=0) segmentations = np.stack(data_list, axis=0)

View File

@ -1,7 +1,6 @@
import argparse import argparse
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# import matplotlib.cm as cm
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Calculate the froc metrics and store in froc_metrics.yml') description='Calculate the froc metrics and store in froc_metrics.yml')
@ -19,6 +18,7 @@ EXPERIMENT = args.experiment
SALIENCY_DIR = f'./../train_output/{EXPERIMENT}_{series_}/saliency.npy' SALIENCY_DIR = f'./../train_output/{EXPERIMENT}_{series_}/saliency.npy'
IMAGES_DIR = f'./../train_output/{EXPERIMENT}_{series_}/images_list.npy' IMAGES_DIR = f'./../train_output/{EXPERIMENT}_{series_}/images_list.npy'
SEGMENTATION_DIR = f'./../train_output/{EXPERIMENT}_{series_}/segmentations.npy' SEGMENTATION_DIR = f'./../train_output/{EXPERIMENT}_{series_}/segmentations.npy'
SLIDE = 10
########## load saliency map ############ ########## load saliency map ############
heatmap = np.load(SALIENCY_DIR) heatmap = np.load(SALIENCY_DIR)
@ -38,13 +38,13 @@ fig, axes = plt.subplots(2,len(SERIES))
print(np.shape(axes)) print(np.shape(axes))
print(np.shape(heatmap)) print(np.shape(heatmap))
print(np.shape(images_list)) print(np.shape(images_list))
max_value = np.amax(heatmap) max_value = np.amax(heatmap[:,:,SLIDE,:])
min_value = np.amin(heatmap) min_value = np.amin(heatmap[:,:,SLIDE,:])
for indx in range(len(SERIES)): for indx in range(len(SERIES)):
print(indx) print(indx)
axes[0,indx].imshow(images_list[:,:,12,indx],cmap='gray') axes[0,indx].imshow(np.transpose(images_list[:,:,SLIDE,indx]),cmap='gray')
im = axes[1,indx].imshow(np.squeeze(heatmap[:,:,12,indx]),vmin=min_value, vmax=max_value) im = axes[1,indx].imshow(np.transpose(np.squeeze(heatmap[:,:,SLIDE,indx])),vmin=min_value, vmax=max_value)
axes[0,indx].set_title(SERIES[indx]) axes[0,indx].set_title(SERIES[indx])
axes[0,indx].set_axis_off() axes[0,indx].set_axis_off()
axes[1,indx].set_axis_off() axes[1,indx].set_axis_off()

View File

@ -28,10 +28,10 @@ df = pd.read_csv(f'{folder_input}')
# read csv file # read csv file
for metric in df: for metric in df:
# if not metric == 'epoch': if not metric == 'epoch':
if metric == 'loss' or metric == 'val_loss': # if metric == 'loss' or metric == 'val_loss':
plt.plot(df['epoch'], df[metric], label=metric) plt.plot(df['epoch'], df[metric], label=metric)
plt.ylim(ymin=0,ymax=0.01) # plt.ylim(ymin=0,ymax=0.01)

View File

@ -1,13 +1,16 @@
quit()
#take one image out
heatmap = np.squeeze(heatmap[0])
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
# Fixing random state for reproducibility ####### dir ############
np.random.seed(19680801) SALIENCY_DIR = f'./../train_output/train_n0.001_run4_t2_b50_b400_b800_b1400_adc/saliency.npy'
########## load saliency map ############
heatmap = np.load(SALIENCY_DIR)
heatmap = np.squeeze(heatmap)
#take one image out
heatmap = np.squeeze(abs(heatmap))
max_value = np.amax(heatmap)
min_value = np.amin(heatmap)
class IndexTracker: class IndexTracker:
@ -19,7 +22,7 @@ class IndexTracker:
rows, cols, self.slices = X.shape rows, cols, self.slices = X.shape
self.ind = self.slices//2 self.ind = self.slices//2
self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet') self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet',vmin=min_value, vmax=max_value)
self.update() self.update()
def on_scroll(self, event): def on_scroll(self, event):
@ -35,14 +38,14 @@ class IndexTracker:
self.ax.set_ylabel('slice %s' % self.ind) self.ax.set_ylabel('slice %s' % self.ind)
self.im.axes.figure.canvas.draw() self.im.axes.figure.canvas.draw()
plt.figure(0) # plt.figure()
fig, ax = plt.subplots(1, 1) # fig, ax = plt.subplots(1, 1)
tracker = IndexTracker(ax, heatmap[:,:,:,5]) # tracker = IndexTracker(ax, heatmap[:,:,:,0])
fig.canvas.mpl_connect('scroll_event', tracker.on_scroll) # fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
plt.show() # plt.show()
plt.figure(1) plt.figure()
fig, ax = plt.subplots(1, 1) fig, ax = plt.subplots(1, 1)
tracker = IndexTracker(ax, heatmap[:,:,:,3]) tracker = IndexTracker(ax, heatmap[:,:,:,4])
fig.canvas.mpl_connect('scroll_event', tracker.on_scroll) fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
plt.show() plt.show()