fast-mri/src/sfransen/DWI_exp/callbacks.py

96 lines
4.0 KiB
Python
Executable File

import numpy as np
import SimpleITK as sitk
from helpers import *
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import roc_auc_score, roc_curve
def dice_coef(y_true, y_pred):
    """Return the Dice coefficient between a target and a rounded prediction.

    Both tensors are flattened. ``y_pred`` is passed through ``K.round``
    before the overlap is measured, so the score compares ``y_true`` with a
    binarised prediction. This assumes ``y_pred`` holds probabilities in
    [0, 1]; TODO confirm.

    ``K.epsilon()`` is added to both numerator and denominator. This keeps
    the ratio finite, and yields 1.0, when both masks are empty.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor, same number of elements as ``y_true``.

    Returns:
        Scalar tensor ``(2 * |A & B| + eps) / (|A| + |B| + eps)``.
    """
    flat_true = K.flatten(y_true)
    flat_pred_bin = K.round(K.flatten(y_pred))
    eps = K.epsilon()

    overlap = K.sum(flat_true * flat_pred_bin)
    denominator = K.sum(flat_true) + K.sum(flat_pred_bin) + eps
    return (2. * overlap + eps) / denominator
# def auc_value(y_true, y_pred):
# # print_("start ROC_AUC")
# # y_true_old = self.validation_set[1].squeeze()
# # y_pred_old = np.around(self.model.predict(self.validation_set[0],batch_size=1).squeeze())
# y_true_auc = K.flatten(y_true)
# # print_(f"The shape of y_true = {np.shape(y_true)}" )
# y_pred_auc = K.round(K.flatten(y_pred))
# # print_(f"The shape of y_pred = {np.shape(y_pred)}" )
# # fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=1)
# auc = roc_auc_score(y_true_auc, y_pred_auc)
# print_('AUC:', auc)
# print_('AUC shape:', np.shape(auc))
# # print_(f'ROC: fpr={fpr} , tpr={tpr}, thresholds={thresholds}')
# return auc
class IntermediateImages(Callback):
    """Keras callback that exports validation inputs, targets and predictions.

    At construction time, the first ``num_images`` validation samples are
    written to disk once: one file per input sequence (channel) plus the
    segmentation target. After every epoch, the model's predictions for the
    same samples are written both raw and rounded. The file names do not
    contain the epoch, so each epoch overwrites the previous one's files.

    File names follow ``{prefix}_{index:03d}_{suffix}.nii.gz``, where
    ``suffix`` is a sequence name, ``seg``, ``pred`` or ``pred_bin``.

    Args:
        validation_set: ``(inputs, targets)`` pair, sliced along the first
            (sample) axis. Each input sample is indexed as
            ``sample[..., s_idx]``, i.e. channels last with one channel per
            entry of ``sequences``.
        prefix: Path prefix shared by every written file.
        sequences: Names of the input channels, in channel order. They are
            used only as file-name suffixes.
        num_images: Maximum number of validation samples to export.
    """

    def __init__(self, validation_set, prefix, sequences,
                 num_images=10):
        # Let the Keras base class set up the attributes it owns
        # (e.g. ``model``/``params``) before Keras attaches the model.
        super().__init__()
        self.prefix = prefix
        self.num_images = num_images
        self.validation_set = (
            validation_set[0][:num_images, ...],
            validation_set[1][:num_images, ...]
        )

        # Scan crops and targets do not change during training, so they are
        # exported once here instead of every epoch.
        inputs, targets = self.validation_set
        for i in range(min(self.num_images, inputs.shape[0])):
            scan = inputs[i]
            for s_idx, s in enumerate(sequences):
                self._write_volume(scan[..., s_idx],
                                   f"{prefix}_{i:03d}_{s}.nii.gz")
            self._write_volume(targets[i], f"{prefix}_{i:03d}_seg.nii.gz")

    def on_epoch_end(self, epoch, logs=None):
        """Write this epoch's predictions for the stored validation samples.

        For each sample this writes two files. ``{prefix}_{i:03d}_pred.nii.gz``
        holds the raw model output. ``{prefix}_{i:03d}_pred_bin.nii.gz`` holds
        the output rounded with ``np.around`` and cast to float32.
        ``epoch`` and ``logs`` are accepted for the Keras callback interface
        but are not used.
        """
        inputs = self.validation_set[0]
        # Pass only the inputs. Handing predict() the (inputs, targets)
        # tuple relies on Keras silently discarding the targets.
        predictions = self.model.predict(inputs, batch_size=1)
        for i in range(min(self.num_images, inputs.shape[0])):
            self._write_volume(predictions[i],
                               f"{self.prefix}_{i:03d}_pred.nii.gz")
            self._write_volume(np.around(predictions[i]).astype(np.float32),
                               f"{self.prefix}_{i:03d}_pred_bin.nii.gz")

    @staticmethod
    def _write_volume(array, path):
        """Squeeze ``array``, reverse its axis order and save it to ``path``."""
        # ``.T`` reverses all axes. This presumably converts the array's
        # axis order into the order SimpleITK's GetImageFromArray expects;
        # TODO confirm against the data layout.
        sitk.WriteImage(sitk.GetImageFromArray(array.squeeze().T), path)
# class RocCallback(Callback):
# def __init__(self,validation_data):
# # self.x = training_data[0]
# # self.y = training_data[1]
# self.x_val = validation_data[0]
# self.y_val = validation_data[1]
# def on_epoch_end(self, epoch, logs={}):
# # y_pred_train = self.model.predict_proba(self.x)
# # roc_train = roc_auc_score(self.y, y_pred_train)
# y_pred_val = self.model.predict(self.x_val, batch_size=1)
# roc_val = roc_auc_score(self.y_val, y_pred_val)
# print('\rroc-auc_train: %s - roc-auc_val: %s' % (str(round(roc_val,4))),end=100*' '+'\n')
# return