working froc, changes in preprocessing
This commit is contained in:
@@ -3,3 +3,4 @@ from .callbacks import *
|
||||
from .helpers import *
|
||||
from .preprocessing_function import *
|
||||
from .unet import *
|
||||
from .losses import *
|
||||
|
78
src/sfransen/DWI_exp/losses.py
Executable file
78
src/sfransen/DWI_exp/losses.py
Executable file
@@ -0,0 +1,78 @@
|
||||
import tensorflow as tf
|
||||
from tensorflow.image import ssim
|
||||
from tensorflow.keras import backend as K
|
||||
|
||||
def weighted_binary_cross_entropy(weights: dict, from_logits: bool = False):
|
||||
'''
|
||||
Return a function for calculating weighted binary cross entropy
|
||||
It should be used for multi-hot encoded labels
|
||||
|
||||
# Example
|
||||
y_true = tf.convert_to_tensor([1, 0, 0, 0, 0, 0], dtype=tf.int64)
|
||||
y_pred = tf.convert_to_tensor([0.6, 0.1, 0.1, 0.9, 0.1, 0.], dtype=tf.float32)
|
||||
weights = {
|
||||
0: 1.,
|
||||
1: 2.
|
||||
}
|
||||
# with weights
|
||||
loss_fn = get_loss_for_multilabels(weights=weights, from_logits=False)
|
||||
loss = loss_fn(y_true, y_pred)
|
||||
print(loss)
|
||||
# tf.Tensor(0.6067193, shape=(), dtype=float32)
|
||||
|
||||
# without weights
|
||||
loss_fn = get_loss_for_multilabels()
|
||||
loss = loss_fn(y_true, y_pred)
|
||||
print(loss)
|
||||
# tf.Tensor(0.52158177, shape=(), dtype=float32)
|
||||
|
||||
# Another example
|
||||
y_true = tf.convert_to_tensor([[0., 1.], [0., 0.]], dtype=tf.float32)
|
||||
y_pred = tf.convert_to_tensor([[0.6, 0.4], [0.4, 0.6]], dtype=tf.float32)
|
||||
weights = {
|
||||
0: 1.,
|
||||
1: 2.
|
||||
}
|
||||
# with weights
|
||||
loss_fn = get_loss_for_multilabels(weights=weights, from_logits=False)
|
||||
loss = loss_fn(y_true, y_pred)
|
||||
print(loss)
|
||||
# tf.Tensor(1.0439969, shape=(), dtype=float32)
|
||||
|
||||
# without weights
|
||||
loss_fn = get_loss_for_multilabels()
|
||||
loss = loss_fn(y_true, y_pred)
|
||||
print(loss)
|
||||
# tf.Tensor(0.81492424, shape=(), dtype=float32)
|
||||
|
||||
@param weights A dict setting weights for 0 and 1 label. e.g.
|
||||
{
|
||||
0: 1.
|
||||
1: 8.
|
||||
}
|
||||
For this case, we want to emphasise those true (1) label,
|
||||
because we have many false (0) label. e.g.
|
||||
[
|
||||
[0 1 0 0 0 0 0 0 0 1]
|
||||
[0 0 0 0 1 0 0 0 0 0]
|
||||
[0 0 0 0 1 0 0 0 0 0]
|
||||
]
|
||||
|
||||
|
||||
|
||||
@param from_logits If False, we apply sigmoid to each logit
|
||||
@return A function to calcualte (weighted) binary cross entropy
|
||||
'''
|
||||
assert 0 in weights
|
||||
assert 1 in weights
|
||||
|
||||
def weighted_cross_entropy_fn(y_true, y_pred):
|
||||
tf_y_true = tf.cast(y_true, dtype=y_pred.dtype)
|
||||
tf_y_pred = tf.cast(y_pred, dtype=y_pred.dtype)
|
||||
|
||||
weights_v = tf.where(tf.equal(tf_y_true, 1), weights[1], weights[0])
|
||||
ce = K.binary_crossentropy(tf_y_true, tf_y_pred, from_logits=from_logits)
|
||||
loss = K.mean(tf.multiply(ce, weights_v))
|
||||
return loss
|
||||
|
||||
return weighted_cross_entropy_fn
|
@@ -187,7 +187,7 @@ def preprocess_softmax(softmax: np.ndarray,
|
||||
def evaluate(
|
||||
y_true: np.ndarray,
|
||||
y_pred: np.ndarray,
|
||||
min_overlap=0.02,
|
||||
min_overlap=0.10,
|
||||
overlap_func: str = 'DSC',
|
||||
case_confidence: str = 'max',
|
||||
multiple_lesion_candidates_selection_criteria='overlap',
|
||||
@@ -382,7 +382,8 @@ def counts_from_lesion_evaluations(
|
||||
TP[i] = tp
|
||||
else:
|
||||
# extend FROC curve to infinity
|
||||
TP[i] = TP[-2]
|
||||
TP[i] = TP[-2] #note: aangepast stefan 11-04-2022
|
||||
# TP[i] = TP[-1]
|
||||
FP[i] = np.inf
|
||||
|
||||
return TP, FP, thresholds, num_lesions
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from typing import List
|
||||
import SimpleITK as sitk
|
||||
from sfransen.DWI_exp.helpers import *
|
||||
import numpy as np
|
||||
|
||||
def load_images_parrallel(
|
||||
image_paths: str,
|
||||
@@ -12,19 +13,31 @@ def load_images_parrallel(
|
||||
|
||||
#resample
|
||||
mri_tra_s = resample(img_s,
|
||||
min_shape=target_shape,
|
||||
min_shape=(s+1 for s in target_shape),
|
||||
method=sitk.sitkNearestNeighbor,
|
||||
new_spacing=target_space)
|
||||
|
||||
#center crop
|
||||
mri_tra_s = center_crop(mri_tra_s, shape=target_shape)
|
||||
#normalize
|
||||
if seq != 'seg':
|
||||
filter = sitk.NormalizeImageFilter()
|
||||
mri_tra_s = filter.Execute(mri_tra_s)
|
||||
else:
|
||||
filter = sitk.BinaryThresholdImageFilter()
|
||||
filter.SetLowerThreshold(1.0)
|
||||
mri_tra_s = filter.Execute(mri_tra_s)
|
||||
# if seq != 'seg':
|
||||
# filter = sitk.NormalizeImageFilter()
|
||||
# mri_tra_s = filter.Execute(mri_tra_s)
|
||||
# else:
|
||||
# filter = sitk.BinaryThresholdImageFilter()
|
||||
# filter.SetLowerThreshold(1.0)
|
||||
# mri_tra_s = filter.Execute(mri_tra_s)
|
||||
|
||||
return sitk.GetArrayFromImage(mri_tra_s).T
|
||||
# return sitk.GetArrayFromImage(mri_tra_s).T
|
||||
|
||||
# Return sitk.Image instead of numpy np.ndarray.
|
||||
|
||||
### method trained in Unet
|
||||
img_n = sitk.GetArrayFromImage(mri_tra_s).T
|
||||
|
||||
if seq != 'seg':
|
||||
image_return = (img_n - np.mean(img_n)) / ( 2* np.std(img_n))
|
||||
else:
|
||||
image_return = np.clip(img_n, 0., 1.)
|
||||
|
||||
return image_return
|
||||
|
Reference in New Issue
Block a user