96 lines
4.0 KiB
Python
Executable File
96 lines
4.0 KiB
Python
Executable File
import numpy as np
|
|
import SimpleITK as sitk
|
|
from helpers import *
|
|
import tensorflow.keras.backend as K
|
|
import tensorflow as tf
|
|
from tensorflow.keras.callbacks import Callback
|
|
from sklearn.metrics import roc_auc_score, roc_curve
|
|
|
|
def dice_coef(y_true, y_pred):
    """Return the Dice coefficient between a target and a rounded prediction.

    Both tensors are flattened, so the score is computed over every element
    of the batch at once rather than per sample. The prediction is passed
    through ``K.round`` before the overlap is measured; the target is used
    as given.

    Args:
        y_true: Target tensor (presumably a binary mask -- confirm with caller).
        y_pred: Predicted tensor, same number of elements as ``y_true``.

    Returns:
        Scalar tensor ``(2 * overlap + eps) / (sum(target) + sum(pred) + eps)``,
        where ``eps`` is ``K.epsilon()``.
    """
    # The epsilon keeps the ratio defined when both tensors sum to zero.
    smooth = K.epsilon()

    target_flat = K.flatten(y_true)
    binarised_flat = K.round(K.flatten(y_pred))

    overlap = K.sum(target_flat * binarised_flat)
    numerator = 2. * overlap + smooth
    denominator = K.sum(target_flat) + K.sum(binarised_flat) + smooth
    return numerator / denominator
|
|
|
|
|
|
# def auc_value(y_true, y_pred):
|
|
# # print_("start ROC_AUC")
|
|
# # y_true_old = self.validation_set[1].squeeze()
|
|
# # y_pred_old = np.around(self.model.predict(self.validation_set[0],batch_size=1).squeeze())
|
|
|
|
# y_true_auc = K.flatten(y_true)
|
|
# # print_(f"The shape of y_true = {np.shape(y_true)}" )
|
|
|
|
# y_pred_auc = K.round(K.flatten(y_pred))
|
|
# # print_(f"The shape of y_pred = {np.shape(y_pred)}" )
|
|
|
|
# # fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=1)
|
|
# auc = roc_auc_score(y_true_auc, y_pred_auc)
|
|
# print_('AUC:', auc)
|
|
# print_('AUC shape:', np.shape(auc))
|
|
# # print_(f'ROC: fpr={fpr} , tpr={tpr}, thresholds={thresholds}')
|
|
# return auc
|
|
|
|
|
|
class IntermediateImages(Callback):
    """Keras callback that dumps validation samples and predictions to disk.

    On construction the first ``num_images`` validation inputs and targets
    are written once as ``.nii.gz`` files. After every epoch the model's
    predictions for those same samples are written, overwriting the files
    from the previous epoch.

    Files written, with ``i`` the zero-padded sample index:
        ``{prefix}_{i}_{sequence}.nii.gz``  one per entry of ``sequences``
        ``{prefix}_{i}_seg.nii.gz``         the target
        ``{prefix}_{i}_pred.nii.gz``        the raw prediction
        ``{prefix}_{i}_pred_bin.nii.gz``    the prediction rounded with
                                            ``np.around`` (stored as float32)
    """

    def __init__(self, validation_set, prefix, sequences,
                 num_images=10):
        """Store a slice of the validation set and export scans and targets.

        Args:
            validation_set: Pair ``(inputs, targets)``; both are sliced along
                their first axis. The last axis of ``inputs`` is indexed by
                position in ``sequences`` (assumes numpy arrays -- confirm
                with caller).
            prefix: Path prefix prepended to every file name written.
            sequences: Iterable of names, one per channel of the inputs,
                used in the exported file names.
            num_images: Maximum number of samples to export.
        """
        # Initialise the base Callback state before adding our own.
        super().__init__()
        self.prefix = prefix
        self.num_images = num_images
        self.validation_set = (
            validation_set[0][:num_images, ...],
            validation_set[1][:num_images, ...]
        )

        # Scan crops and targets do not change during training, so they are
        # exported only once, here.
        inputs, targets = self.validation_set
        for i in range(self._export_count()):
            for s_idx, s in enumerate(sequences):
                img_s = sitk.GetImageFromArray(
                    inputs[i][..., s_idx].squeeze().T)
                sitk.WriteImage(img_s, f"{prefix}_{i:03d}_{s}.nii.gz")

            seg_s = sitk.GetImageFromArray(targets[i].squeeze().T)
            sitk.WriteImage(seg_s, f"{prefix}_{i:03d}_seg.nii.gz")

    def _export_count(self):
        """Return how many samples are exported (never more than are stored)."""
        return min(self.num_images, self.validation_set[0].shape[0])

    def on_epoch_end(self, epoch, logs=None):
        """Write raw and rounded predictions for the stored samples.

        Args:
            epoch: Index of the epoch that just finished (unused).
            logs: Metric dictionary supplied by Keras (unused).
        """
        # Predict on the inputs only. Passing the whole (inputs, targets)
        # tuple would hand the targets to the model as a second input.
        predictions = self.model.predict(self.validation_set[0], batch_size=1)

        for i in range(self._export_count()):
            prediction = predictions[i]
            prd_s = sitk.GetImageFromArray(prediction.squeeze().T)
            prd_bin_s = sitk.GetImageFromArray(
                np.around(prediction).astype(np.float32).squeeze().T)
            sitk.WriteImage(prd_s, f"{self.prefix}_{i:03d}_pred.nii.gz")
            sitk.WriteImage(prd_bin_s, f"{self.prefix}_{i:03d}_pred_bin.nii.gz")
|
|
|
|
# class RocCallback(Callback):
|
|
# def __init__(self,validation_data):
|
|
# # self.x = training_data[0]
|
|
# # self.y = training_data[1]
|
|
# self.x_val = validation_data[0]
|
|
# self.y_val = validation_data[1]
|
|
|
|
# def on_epoch_end(self, epoch, logs={}):
|
|
# # y_pred_train = self.model.predict_proba(self.x)
|
|
# # roc_train = roc_auc_score(self.y, y_pred_train)
|
|
# y_pred_val = self.model.predict(self.x_val, batch_size=1)
|
|
# roc_val = roc_auc_score(self.y_val, y_pred_val)
|
|
# print('\rroc-auc_val: %s' % (str(round(roc_val,4))),end=100*' '+'\n')
|
|
# return |