commit before migration to habrok

This commit is contained in:
Stefan
2023-03-29 13:05:32 +02:00
parent 9468dadfa3
commit ad595bb25e
24 changed files with 4098 additions and 1 deletions

6
code/DWI_exp/__init__.py Executable file
View File

@@ -0,0 +1,6 @@
print('verkeerde code')
from .batchgenerator import *
from .callbacks import *
from .helpers import *
from .preprocessing_function import *
from .unet import *

64
code/DWI_exp/batchgenerator.py Executable file
View File

@@ -0,0 +1,64 @@
import math
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
from tensorflow.keras.utils import Sequence
class BatchGenerator(Sequence):
def __init__(self,
images: Dict[str, List[np.ndarray]],
segmentations: List[np.ndarray],
sequences: List[str],
shape: Tuple[int],
indexes: List[int] = None,
batch_size: int = 5,
shuffle: bool = False,
augmentation_function: Optional[Callable\
[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]] = None):
self.images = images
self.segmentations = segmentations
self.sequences = sequences
self.indexes = indexes
self.batch_size = batch_size
self.shuffle = shuffle
self.augmentation = augmentation_function
self.num_rows = len(images)
if self.indexes == None:
self.indexes = list(range(self.num_rows))
if type(indexes) == int:
self.indexes = list(range(self.indexes))
if self.batch_size == None:
self.batch_size = len(self.indexes)
# Stack image index as index 0, stack sequences as index -1
self._images = np.stack(
[np.stack(self.images[k], axis=0) for k in sequences], axis=-1)
self._segmentations = np.stack(segmentations, axis=0)[..., np.newaxis]
def __len__(self):
return math.ceil(len(self.indexes) / self.batch_size)
def __getitem__(self, batch_idx):
# Get image / segmentation indexes for current batch
indexes = self.indexes[batch_idx*self.batch_size:(1+batch_idx)*self.batch_size]
# Get matching images and segmentations
images = self._images[indexes]
segmentations = self._segmentations[indexes]
# Apply data augmentation if a function is provided
if self.augmentation:
for img_idx in range(len(indexes)):
images[img_idx], segmentations[img_idx] = self.augmentation(
images[img_idx], segmentations[img_idx])
return images, segmentations
def on_epoch_end(self):
super().on_epoch_end()
if self.shuffle:
np.random.shuffle(self.indexes)

96
code/DWI_exp/callbacks.py Executable file
View File

@@ -0,0 +1,96 @@
import numpy as np
import SimpleITK as sitk
from helpers import *
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import roc_auc_score, roc_curve
def dice_coef(y_true, y_pred):
y_true_f = K.flatten(y_true)
y_pred_f = K.round(K.flatten(y_pred))
intersection = K.sum(y_true_f * y_pred_f)
return (2. * intersection + K.epsilon()) / (
K.sum(y_true_f) + K.sum(y_pred_f) + K.epsilon())
# def auc_value(y_true, y_pred):
# # print_("start ROC_AUC")
# # y_true_old = self.validation_set[1].squeeze()
# # y_pred_old = np.around(self.model.predict(self.validation_set[0],batch_size=1).squeeze())
# y_true_auc = K.flatten(y_true)
# # print_(f"The shape of y_true = {np.shape(y_true)}" )
# y_pred_auc = K.round(K.flatten(y_pred))
# # print_(f"The shape of y_pred = {np.shape(y_pred)}" )
# # fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=1)
# auc = roc_auc_score(y_true_auc, y_pred_auc)
# print_('AUC:', auc)
# print_('AUC shape:', np.shape(auc))
# # print_(f'ROC: fpr={fpr} , tpr={tpr}, thresholds={thresholds}')
# return auc
class IntermediateImages(Callback):
def __init__(self, validation_set, prefix, sequences,
num_images=10):
self.prefix = prefix
self.num_images = num_images
self.validation_set = (
validation_set[0][:num_images, ...],
validation_set[1][:num_images, ...]
)
# Export scan crops and targets once
# they don't change during training so we export them only once
for i in range(min(self.num_images, self.validation_set[0].shape[0])):
for s_idx, s in enumerate(sequences):
img_s = sitk.GetImageFromArray(
self.validation_set[0][i][..., s_idx].squeeze().T)
sitk.WriteImage(img_s, f"{prefix}_{i:03d}_{s}.nii.gz")
seg_s = sitk.GetImageFromArray(
self.validation_set[1][i].squeeze().T)
sitk.WriteImage(seg_s, f"{prefix}_{i:03d}_seg.nii.gz")
def on_epoch_end(self, epoch, logs={}):
# Predict on the validation_set
predictions = self.model.predict(self.validation_set, batch_size=1)
# print_("start ROC_AUC")
# y_true_old = self.validation_set[1].squeeze()
# y_pred_old = np.around(self.model.predict(self.validation_set[0],batch_size=1).squeeze())
# y_true = np.array(y_true_old).flatten()
# print_(f"The shape of y_true = {np.shape(y_true)}" )
# y_pred = np.array(y_pred_old).flatten()
# print_(f"The shape of y_pred = {np.shape(y_pred)}" )
# fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=1)
# auc = roc_auc_score(y_true, y_pred)
# print_('AUC:', auc)
# print_(f'ROC: fpr={fpr} , tpr={tpr}, thresholds={thresholds}')
for i in range(min(self.num_images, self.validation_set[0].shape[0])):
prd_s = sitk.GetImageFromArray(predictions[i].squeeze().T)
prd_bin_s = sitk.GetImageFromArray(
np.around(predictions[i]).astype(np.float32).squeeze().T)
sitk.WriteImage(prd_s, f"{self.prefix}_{i:03d}_pred.nii.gz")
sitk.WriteImage(prd_bin_s, f"{self.prefix}_{i:03d}_pred_bin.nii.gz")
# class RocCallback(Callback):
# def __init__(self,validation_data):
# # self.x = training_data[0]
# # self.y = training_data[1]
# self.x_val = validation_data[0]
# self.y_val = validation_data[1]
# def on_epoch_end(self, epoch, logs={}):
# # y_pred_train = self.model.predict_proba(self.x)
# # roc_train = roc_auc_score(self.y, y_pred_train)
# y_pred_val = self.model.predict(self.x_val, batch_size=1)
# roc_val = roc_auc_score(self.y_val, y_pred_val)
# print('\rroc-auc_train: %s - roc-auc_val: %s' % (str(round(roc_val,4))),end=100*' '+'\n')
# return

239
code/DWI_exp/helpers.py Executable file
View File

@@ -0,0 +1,239 @@
import os
from os import path
import numpy as np
import SimpleITK as sitk
from scipy import ndimage
import yaml
def print_(*args, **kwargs):
print(*args, **kwargs, flush=True)
def print_config(config, depth=0):
for k in config:
print(end=(" "*depth*2) + f"- {k}:")
if type(config[k]) == dict:
print()
print_config(config[k], depth+1)
else:
print(f"\t{config[k]}")
print_(end="")
def prepare_project_dir(project_dir):
"""
Prepares the work directory for a project
TODO: Create text file containing configuration
"""
def _makedir(n):
os.makedirs(path.join(project_dir, n), exist_ok=True)
_makedir("models")
_makedir("logs")
_makedir("data")
_makedir("docs")
_makedir("samples")
_makedir("output")
def fake_loss(*args, **kwargs):
return 0.
def center_crop(
image: sitk.Image, shape: list, offset: list = None
) -> sitk.Image:
"""Extracts a region of interest from the center of an SITK image.
Parameters:
image: Input image (SITK).
shape: The shape of the
Returns: Cropped image (SITK)
"""
size = image.GetSize()
# Determine the centroid
centroid = [sz / 2 for sz in size]
# Determine the origin for the bounding box by subtracting half the
# shape of the bounding box from each dimension of the centroid.
box_start = [int(c - sh / 2) for c, sh in zip(centroid, shape)]
if offset:
box_start = [b - o for b, o in zip(box_start, offset)]
# Extract the region of provided shape starting from the previously
# determined starting pixel.
region_extractor = sitk.RegionOfInterestImageFilter()
region_extractor.SetSize(shape)
region_extractor.SetIndex(box_start)
cropped_image = region_extractor.Execute(image)
return cropped_image
def augment(img, seg,
noise_chance = 0.3,
noise_mult_max = 0.001,
rotate_chance = 0.2,
rotate_max_angle = 30,
):
#noise_mult_max was initially 0.1
if np.random.uniform() < noise_chance:
img = augment_noise(img, np.random.uniform(0., noise_mult_max))
if np.random.uniform() < rotate_chance:
img, seg = augment_rotate(img, seg,
angle=np.random.uniform(0-rotate_max_angle, rotate_max_angle))
img, seg = augment_tilt(img, seg,
angle=np.random.uniform(0-rotate_max_angle, rotate_max_angle))
return img, seg
def augment_noise(img, multiplier):
noise = np.random.standard_normal(img.shape) * multiplier
return img + noise
def augment_rotate(img, seg, angle):
img = ndimage.rotate(img, angle, reshape=False, cval=0)
seg = ndimage.rotate(seg, angle, reshape=False, order=0)
return img, seg
def augment_tilt(img, seg, angle):
img = ndimage.rotate(img, angle, reshape=False, axes=(2,1), cval=0)
seg = ndimage.rotate(seg, angle, reshape=False, axes=(2,1), order=0)
return img, seg
def resample(
image: sitk.Image, min_shape: list, method=sitk.sitkLinear,
new_spacing: list=[1, 1, 3.6]
) -> sitk.Image:
"""Resamples an image to given target spacing and shape.
Parameters:
image: Input image (SITK).
shape: Minimum output shape for the underlying array.
method: SimpleITK interpolator to use for resampling.
(e.g. sitk.sitkNearestNeighbor, sitk.sitkLinear)
new_spacing: The new spacing to resample to.
Returns:
int: Resampled image
"""
# Extract size and spacing from the image
size = image.GetSize()
spacing = image.GetSpacing()
# Determine how much larger the image will become with the new spacing
factor = [sp / new_sp for sp, new_sp in zip(spacing, new_spacing)]
# Determine the outcome size of the image for each dimension
get_size = lambda size, factor, min_shape: max(int(size * factor), min_shape)
new_size = [get_size(sz, f, sh) for sz, f, sh in zip(size, factor, min_shape)]
# Resample the image
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(image)
resampler.SetOutputSpacing(new_spacing)
resampler.SetSize(new_size)
resampler.SetInterpolator(method)
resampled_image = resampler.Execute(image)
return resampled_image
# Start training
def get_generator(
images: dict,
segmentations: list,
sequences: list,
shape: tuple,
indexes: list = None,
batch_size: int = 5,
shuffle: bool = False,
augmentation = True
):
"""
Returns a (training) generator for use with model.fit().
Parameters:
input_modalities: List of modalty names to include.
output_modalities: Names of the target modalities.
batch_size: Number of images per batch (default: all).
indexes: Only use the specified image indexes.
shuffle: Shuffle the lists of indexes once at the beginning.
augmentation: Apply augmentation or not (bool).
"""
num_rows = len(images)
if indexes == None:
indexes = list(range(num_rows))
if type(indexes) == int:
indexes = list(range(indexes))
if batch_size == None:
batch_size = len(indexes)
idx = 0
# Prepare empty batch placeholder with named inputs and outputs
input_batch = np.zeros((batch_size,) + shape + (len(images),))
output_batch = np.zeros((batch_size,) + shape + (1,))
# Loop infinitely to keep generating batches
while True:
# Prepare each observation in a batch
for batch_idx in range(batch_size):
# Shuffle the order of images if all indexes have been seen
if idx == 0 and shuffle:
np.random.shuffle(indexes)
current_index = indexes[idx]
# Insert the augmented images into the input batch
img_crop = [images[s][current_index] for s in sequences]
img_crop = np.stack(img_crop, axis=-1)
seg_crop = segmentations[current_index][..., np.newaxis]
if augmentation:
img_crop, seg_crop = augment(img_crop, seg_crop)
input_batch[batch_idx] = img_crop
output_batch[batch_idx] = seg_crop
# Increase the current index and modulo by the number of rows
# so that we stay within bounds
idx = (idx + 1) % len(indexes)
yield input_batch, output_batch
def resample_to_reference(image, ref_img,
interpolator=sitk.sitkNearestNeighbor,
default_pixel_value=0):
resampled_img = sitk.Resample(image, ref_img,
sitk.Transform(),
interpolator, default_pixel_value,
ref_img.GetPixelID())
return resampled_img
def dump_dict_to_yaml(
data: dict,
target_dir: str,
filename: str = "settings") -> None:
""" Writes the given dictionary as a yaml to the target directory.
Parameters:
`data (dict)`: dictionary of data.
`target_dir (str)`: directory where the yaml will be saved.
`` filename (str)`: name of the file without extension
"""
print("\nParameters")
for pair in data.items():
print(f"\t{pair}")
print()
path = f"{target_dir}/{filename}.yml"
print_(f"Wrote yaml to: {path}")
with open(path, 'w') as outfile:
yaml.dump(data, outfile, default_flow_style=False)

View File

@@ -0,0 +1,46 @@
import numpy as np
import SimpleITK as sitk
from helpers import *
def preprocess(imgs: dict, seg: sitk.Image,
shape: tuple, spacing: tuple, to_numpy=True) -> tuple:
# Resample all of the images to the desired voxel spacing
img_r = {k: resample(imgs[k],
min_shape=(s+1 for s in shape),
new_spacing=spacing) for k in imgs}
seg_r = resample(seg,
min_shape=(s+1 for s in shape),
new_spacing=spacing,
method=sitk.sitkNearestNeighbor)
# Center crop one of the input images
ref_seq = [k for k in img_r.keys()][0]
ref_img = center_crop(img_r[ref_seq], shape=shape)
# Then crop the remaining series / segmentation to match the input crop by
# transforming them on the physical space of the cropped series.
img_crop = {k: resample_to_reference(
image=img_r[k],
ref_img=ref_img,
interpolator=sitk.sitkLinear)
for k in img_r}
seg_crop = resample_to_reference(
image=seg_r,
ref_img=ref_img,
interpolator=sitk.sitkNearestNeighbor)
# Return sitk.Image instead of numpy np.ndarray.
if not to_numpy:
return img_crop, seg_crop
img_n = {k: sitk.GetArrayFromImage(img_crop[k]).T for k in img_crop}
seg_n = sitk.GetArrayFromImage(seg_crop).T
seg_n = np.clip(seg_n, 0., 1.)
# Z-score normalize all images
# to do: 2*std
for seq in img_n:
img_n[seq] = (img_n[seq] - np.mean(img_n[seq])) / ( 2* np.std(img_n[seq]))
return img_n, seg_n

114
code/DWI_exp/unet.py Executable file
View File

@@ -0,0 +1,114 @@
from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv3D, Activation, Dense, concatenate, add
from tensorflow.keras.layers import Conv3DTranspose, LeakyReLU
from tensorflow.keras import regularizers
from tensorflow.keras.layers import GlobalAveragePooling3D, Reshape, Dense, multiply, Permute
def squeeze_excite_block(input, ratio=8):
''' Create a channel-wise squeeze-excite block
Args:
input: input tensor
filters: number of output filters
Returns: a keras tensor
References
- [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
'''
init = input
channel_axis = 1 if K.image_data_format() == "channels_first" else -1
filters = init.shape[channel_axis]
se_shape = (1, 1, 1, filters)
se = GlobalAveragePooling3D()(init)
se = Reshape(se_shape)(se)
se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)
if K.image_data_format() == 'channels_first':
se = Permute((4, 1, 2, 3))(se)
x = multiply([init, se])
return x
def build_dual_attention_unet(
input_shape,
l2_regularization = 0.0001,
):
def conv_layer(x, kernel_size, out_filters, strides=(1,1,1)):
x = Conv3D(out_filters, kernel_size,
strides = strides,
padding = 'same',
kernel_regularizer = regularizers.l2(l2_regularization),
kernel_initializer = 'he_normal',
use_bias = False
)(x)
return x
def conv_block(input, out_filters, strides=(1,1,1), with_residual=False, with_se=False, activation='relu'):
# Strided convolution to convsample
x = conv_layer(input, (3,3,3), out_filters, strides)
x = Activation('relu')(x)
# Unstrided convolution
x = conv_layer(x, (3,3,3), out_filters)
# Add a squeeze-excite block
if with_se:
se = squeeze_excite_block(x)
x = add([x, se])
# Add a residual connection using a 1x1x1 convolution with strides
if with_residual:
residual = conv_layer(input, (1,1,1), out_filters, strides)
x = add([x, residual])
if activation == 'leaky':
x = LeakyReLU(alpha=.1)(x)
else:
x = Activation('relu')(x)
# Activate whatever comes out of this
return x
# If we already have only one input, no need to combine anything
inputs = Input(input_shape)
# Downsampling
conv1 = conv_block(inputs, 16)
conv2 = conv_block(conv1, 32, strides=(2,2,1), with_residual=True, with_se=True) #72x72x18
conv3 = conv_block(conv2, 64, strides=(2,2,1), with_residual=True, with_se=True) #36x36x18
conv4 = conv_block(conv3, 128, strides=(2,2,2), with_residual=True, with_se=True) #18x18x9
conv5 = conv_block(conv4, 256, strides=(2,2,2), with_residual=True, with_se=True) #9x9x9
# First upsampling sequence
up1_1 = Conv3DTranspose(128, (3,3,3), strides=(2,2,2), padding='same')(conv5) #18x18x9
up1_2 = Conv3DTranspose(128, (3,3,3), strides=(2,2,2), padding='same')(up1_1) #36x36x18
up1_3 = Conv3DTranspose(128, (3,3,3), strides=(2,2,1), padding='same')(up1_2) #72x72x18
bridge1 = concatenate([conv4, up1_1]) #18x18x9 (128+128=256)
dec_conv_1 = conv_block(bridge1, 128, with_residual=True, with_se=True, activation='leaky') #18x18x9
# Second upsampling sequence
up2_1 = Conv3DTranspose(64, (3,3,3), strides=(2,2,2), padding='same')(dec_conv_1) # 36x36x18
up2_2 = Conv3DTranspose(64, (3,3,3), strides=(2,2,1), padding='same')(up2_1) # 72x72x18
bridge2 = concatenate([conv3, up1_2, up2_1]) # 36x36x18 (64+128+64=256)
dec_conv_2 = conv_block(bridge2, 64, with_residual=True, with_se=True, activation='leaky')
# Final upsampling sequence
up3_1 = Conv3DTranspose(32, (3,3,3), strides=(2,2,1), padding='same')(dec_conv_2) # 72x72x18
bridge3 = concatenate([conv2, up1_3, up2_2, up3_1]) # 72x72x18 (32+128+64+32=256)
dec_conv_3 = conv_block(bridge3, 32, with_residual=True, with_se=True, activation='leaky')
# Last upsampling to make heatmap
up4_1 = Conv3DTranspose(16, (3,3,3), strides=(2,2,1), padding='same')(dec_conv_3) # 72x72x18
dec_conv_4 = conv_block(up4_1, 16, with_residual=False, with_se=True, activation='leaky') #144x144x18 (16)
# Reduce to a single output channel with a 1x1x1 convolution
single_channel = Conv3D(1, (1, 1, 1))(dec_conv_4)
# Apply sigmoid activation to get binary prediction per voxel
act = Activation('sigmoid')(single_channel)
# Model definition
model = Model(inputs=inputs, outputs=act)
return model