small adjustments for display

Stefan 2022-03-23 17:00:22 +01:00
parent 01f458d0db
commit 82c285b1b0
7 changed files with 38 additions and 35 deletions

View File

@@ -6,10 +6,11 @@ from datetime import datetime
import sys
# sys.path.append('./../code')
# from utils_quintin import *
# from sfransen.utils_quintin import *
from sfransen.utils_quintin import *
# sys.path.append('./../code/DWI_exp')
# from callbacks import IntermediateImages, dice_coef
# from callbacks import RocCallback
from sfransen.utils_quintin import *
from sfransen.DWI_exp import IntermediateImages, dice_coef
from sfransen.DWI_exp.preprocessing_function import preprocess
import yaml

View File

@@ -35,6 +35,7 @@ EXPERIMENT = args.experiment
MODEL_PATH = f'./../train_output/{EXPERIMENT}_{series_}/models/{EXPERIMENT}_{series_}.h5'
YAML_DIR = f'./../train_output/{EXPERIMENT}_{series_}'
IMAGE_DIR = f'./../train_output/{EXPERIMENT}_{series_}'
DATA_DIR = "./../data/Nijmegen paths/"
TARGET_SPACING = (0.5, 0.5, 3)
@@ -42,10 +43,11 @@ INPUT_SHAPE = (192, 192, 24, len(SERIES))
IMAGE_SHAPE = INPUT_SHAPE[:3]
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = DATA_SPLIT_INDEX['test_set0']
TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
N_CPUS = 12
########## load images in parallel ##############
print_(f"> Loading images into RAM...")
@@ -116,8 +118,8 @@ predictions = [preprocess_softmax(pred, threshold="dynamic")[0] for pred in pred
# Remove outer edges
zeros = np.zeros(np.shape(predictions))
test = np.squeeze(predictions)[:,:,2:190,2:190]
zeros[:,:,2:190,2:190] = test
test = np.squeeze(predictions)[:,2:-2,2:190,2:190]
zeros[:,2:-2,2:190,2:190] = test
predictions = zeros
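A minimal, hedged restatement of the border-masking step above: a 2-voxel margin of each predicted volume is zeroed so that edge artefacts cannot contribute detections. The helper name and the (cases, slices, height, width) shape are assumptions for illustration.

import numpy as np

def zero_border(pred, margin=2):
    # pred is assumed to have shape (n_cases, slices, height, width)
    masked = np.zeros_like(pred)
    inner = (slice(None),) + (slice(margin, -margin),) * 3
    masked[inner] = pred[inner]
    return masked

# e.g. predictions = zero_border(np.squeeze(predictions))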
# perform Froc
@@ -127,7 +129,6 @@ dump_dict_to_yaml(metrics, YAML_DIR, "froc_metrics", verbose=True)
############## save image as example #################
# save image number 3
IMAGE_DIR = f'./../train_output/train_10h_{series_}'
img_s = sitk.GetImageFromArray(predictions_blur[3].squeeze())
sitk.WriteImage(img_s, f"{IMAGE_DIR}/predictions_blur_001.nii.gz")
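As a side note, a hedged sketch of writing such a volume as NIfTI with SimpleITK; copying geometry from a reference image is an optional extra not done in the diff above, and requires that the array has the same size as the reference.

import numpy as np
import SimpleITK as sitk

def save_as_nifti(volume, path, reference=None):
    img = sitk.GetImageFromArray(volume.astype(np.float32))   # expects (z, y, x)
    if reference is not None:
        img.CopyInformation(reference)    # copies spacing, origin and direction
    sitk.WriteImage(img, path)

# save_as_nifti(predictions_blur[3].squeeze(), f"{IMAGE_DIR}/predictions_blur_001.nii.gz")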

View File

@@ -1,6 +1,4 @@
import sys
sys.path.append('./../code')
from utils_quintin import *
from sfransen.utils_quintin import *
import matplotlib.pyplot as plt
import argparse

View File

@@ -44,10 +44,10 @@ froc_metrics = read_yaml_to_dict(f'{YAML_DIR}/froc_metrics.yml')
top_10_idx = np.argsort(froc_metrics['roc_pred'])[-1 :]
DATA_SPLIT_INDEX = read_yaml_to_dict('./../data/Nijmegen paths/train_val_test_idxs.yml')
TEST_INDEX = DATA_SPLIT_INDEX['test_set0']
TEST_INDEX = DATA_SPLIT_INDEX['val_set0']
TEST_INDEX_top10 = [TEST_INDEX[i] for i in top_10_idx]
TEST_INDEX_image = [371]
N_CPUS = 12
########## load images in parallel ##############
@@ -77,7 +77,7 @@ partial_seg = partial(load_images_parrallel,
images = []
for s in SERIES:
image_paths_seq = image_paths[s]
image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX_top10]
image_paths_index = np.asarray(image_paths_seq)[TEST_INDEX_image]
data_list = pool.map(partial_images,image_paths_index)
data = np.stack(data_list, axis=0)
images.append(data)
@@ -85,7 +85,7 @@ for s in SERIES:
images_list = np.transpose(images, (1, 2, 3, 4, 0))
#load segmentations
seg_paths_index = np.asarray(seg_paths)[TEST_INDEX_top10]
seg_paths_index = np.asarray(seg_paths)[TEST_INDEX_image]
data_list = pool.map(partial_seg,seg_paths_index)
segmentations = np.stack(data_list, axis=0)
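The loading pattern above (functools.partial plus multiprocessing.Pool.map) as a self-contained, hedged sketch; load_images_parrallel itself is defined elsewhere in the repository, so the loader body and the example paths here are assumptions.

from functools import partial
from multiprocessing import Pool

import numpy as np
import SimpleITK as sitk

def load_one(path, spacing=(0.5, 0.5, 3.0)):
    # Placeholder loader: read one scan and return it as a float32 array (z, y, x).
    return sitk.GetArrayFromImage(sitk.ReadImage(path)).astype(np.float32)

if __name__ == "__main__":
    paths = ["case_0001.nii.gz", "case_0002.nii.gz"]   # hypothetical paths
    loader = partial(load_one, spacing=(0.5, 0.5, 3.0))
    with Pool(processes=12) as pool:                    # N_CPUS = 12 in the script above
        volumes = pool.map(loader, paths)
    data = np.stack(volumes, axis=0)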

View File

@@ -1,7 +1,6 @@
import argparse
import numpy as np
import matplotlib.pyplot as plt
# import matplotlib.cm as cm
parser = argparse.ArgumentParser(
description='Calculate the froc metrics and store in froc_metrics.yml')
@@ -19,6 +18,7 @@ EXPERIMENT = args.experiment
SALIENCY_DIR = f'./../train_output/{EXPERIMENT}_{series_}/saliency.npy'
IMAGES_DIR = f'./../train_output/{EXPERIMENT}_{series_}/images_list.npy'
SEGMENTATION_DIR = f'./../train_output/{EXPERIMENT}_{series_}/segmentations.npy'
SLIDE = 10
########## load saliency map ############
heatmap = np.load(SALIENCY_DIR)
@@ -38,13 +38,13 @@ fig, axes = plt.subplots(2,len(SERIES))
print(np.shape(axes))
print(np.shape(heatmap))
print(np.shape(images_list))
max_value = np.amax(heatmap)
min_value = np.amin(heatmap)
max_value = np.amax(heatmap[:,:,SLIDE,:])
min_value = np.amin(heatmap[:,:,SLIDE,:])
for indx in range(len(SERIES)):
print(indx)
axes[0,indx].imshow(images_list[:,:,12,indx],cmap='gray')
im = axes[1,indx].imshow(np.squeeze(heatmap[:,:,12,indx]),vmin=min_value, vmax=max_value)
axes[0,indx].imshow(np.transpose(images_list[:,:,SLIDE,indx]),cmap='gray')
im = axes[1,indx].imshow(np.transpose(np.squeeze(heatmap[:,:,SLIDE,indx])),vmin=min_value, vmax=max_value)
axes[0,indx].set_title(SERIES[indx])
axes[0,indx].set_axis_off()
axes[1,indx].set_axis_off()
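The same two-row layout (images on top, saliency maps below, one colour scale for all heatmap panels) as a stand-alone, hedged sketch with random stand-in data; the series names, shapes and the added colourbar are illustrative assumptions.

import numpy as np
import matplotlib.pyplot as plt

series = ["t2", "adc", "b1400"]                        # hypothetical series names
images = np.random.rand(192, 192, 24, len(series))     # stand-in data
heatmap = np.random.rand(192, 192, 24, len(series))
slide = 10

# Fix vmin/vmax on the plotted slice so all heatmap panels share one colour scale.
vmin, vmax = heatmap[:, :, slide, :].min(), heatmap[:, :, slide, :].max()
fig, axes = plt.subplots(2, len(series), figsize=(3 * len(series), 6))
for i, name in enumerate(series):
    axes[0, i].imshow(images[:, :, slide, i].T, cmap="gray")
    im = axes[1, i].imshow(heatmap[:, :, slide, i].T, vmin=vmin, vmax=vmax)
    axes[0, i].set_title(name)
    axes[0, i].set_axis_off()
    axes[1, i].set_axis_off()
fig.colorbar(im, ax=axes[1, :].tolist(), shrink=0.8)
plt.show()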

View File

@@ -28,10 +28,10 @@ df = pd.read_csv(f'{folder_input}')
# read csv file
for metric in df:
# if not metric == 'epoch':
if metric == 'loss' or metric == 'val_loss':
if not metric == 'epoch':
# if metric == 'loss' or metric == 'val_loss':
plt.plot(df['epoch'], df[metric], label=metric)
plt.ylim(ymin=0,ymax=0.01)
# plt.ylim(ymin=0,ymax=0.01)
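A hedged, self-contained version of the metric plot above: read a Keras-style CSV log with pandas and plot every column except 'epoch'; the file name is a placeholder.

import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("training_log.csv")       # hypothetical path to the CSV log
for metric in df.columns:
    if metric != "epoch":
        plt.plot(df["epoch"], df[metric], label=metric)
plt.xlabel("epoch")
plt.ylabel("value")
plt.legend()
# plt.ylim(0, 0.01)   # optional zoom on the loss curves, as in the commented line above
plt.show()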

View File

@@ -1,13 +1,16 @@
quit()
#take one image out
heatmap = np.squeeze(heatmap[0])
import numpy as np
import matplotlib.pyplot as plt
# Fixing random state for reproducibility
np.random.seed(19680801)
####### dir ############
SALIENCY_DIR = f'./../train_output/train_n0.001_run4_t2_b50_b400_b800_b1400_adc/saliency.npy'
########## load saliency map ############
heatmap = np.load(SALIENCY_DIR)
heatmap = np.squeeze(heatmap)
#take one image out
heatmap = np.squeeze(abs(heatmap))
max_value = np.amax(heatmap)
min_value = np.amin(heatmap)
class IndexTracker:
@@ -19,7 +22,7 @@ class IndexTracker:
rows, cols, self.slices = X.shape
self.ind = self.slices//2
self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet')
self.im = ax.imshow(self.X[:, :, self.ind], cmap='jet',vmin=min_value, vmax=max_value)
self.update()
def on_scroll(self, event):
@@ -35,14 +38,14 @@ class IndexTracker:
self.ax.set_ylabel('slice %s' % self.ind)
self.im.axes.figure.canvas.draw()
plt.figure(0)
fig, ax = plt.subplots(1, 1)
tracker = IndexTracker(ax, heatmap[:,:,:,5])
fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
plt.show()
# plt.figure()
# fig, ax = plt.subplots(1, 1)
# tracker = IndexTracker(ax, heatmap[:,:,:,0])
# fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
# plt.show()
plt.figure(1)
plt.figure()
fig, ax = plt.subplots(1, 1)
tracker = IndexTracker(ax, heatmap[:,:,:,3])
tracker = IndexTracker(ax, heatmap[:,:,:,4])
fig.canvas.mpl_connect('scroll_event', tracker.on_scroll)
plt.show()
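Finally, a self-contained, hedged sketch of the scroll-through-slices viewer along the lines of IndexTracker above, with vmin/vmax fixed over the whole volume so the colour scale does not change per slice; the class name and the random stand-in volume are illustrative.

import numpy as np
import matplotlib.pyplot as plt

class SliceScroller:
    def __init__(self, ax, volume):
        self.ax = ax
        self.volume = volume                       # shape (rows, cols, slices)
        self.ind = volume.shape[2] // 2
        self.im = ax.imshow(volume[:, :, self.ind], cmap="jet",
                            vmin=volume.min(), vmax=volume.max())
        self.update()

    def on_scroll(self, event):
        step = 1 if event.button == "up" else -1
        self.ind = (self.ind + step) % self.volume.shape[2]
        self.update()

    def update(self):
        self.im.set_data(self.volume[:, :, self.ind])
        self.ax.set_ylabel(f"slice {self.ind}")
        self.im.axes.figure.canvas.draw()

if __name__ == "__main__":
    vol = np.abs(np.random.randn(192, 192, 24))    # stand-in for one saliency channel
    fig, ax = plt.subplots(1, 1)
    tracker = SliceScroller(ax, vol)
    fig.canvas.mpl_connect("scroll_event", tracker.on_scroll)
    plt.show()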