commit before migration to habrok
code/DWI_exp/__init__.py (Executable file, 6 lines)
@@ -0,0 +1,6 @@
print('wrong code')  # translated; original read 'verkeerde code' (Dutch)
from .batchgenerator import *
from .callbacks import *
from .helpers import *
from .preprocessing_function import *
from .unet import *
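With this layout everything is re-exported at package level; assuming the repository's code/ directory is on PYTHONPATH, a consumer script could simply do (hypothetical usage, not part of the commit):

    from DWI_exp import BatchGenerator, preprocess, build_dual_attention_unet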
code/DWI_exp/batchgenerator.py (Executable file, 64 lines)
@@ -0,0 +1,64 @@
import math
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
from tensorflow.keras.utils import Sequence


class BatchGenerator(Sequence):
    def __init__(self,
                 images: Dict[str, List[np.ndarray]],
                 segmentations: List[np.ndarray],
                 sequences: List[str],
                 shape: Tuple[int],
                 indexes: List[int] = None,
                 batch_size: int = 5,
                 shuffle: bool = False,
                 augmentation_function: Optional[
                     Callable[[np.ndarray, np.ndarray],
                              Tuple[np.ndarray, np.ndarray]]] = None):

        self.images = images
        self.segmentations = segmentations
        self.sequences = sequences
        self.indexes = indexes
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augmentation = augmentation_function

        self.num_rows = len(images)
        if self.indexes is None:
            self.indexes = list(range(self.num_rows))

        if isinstance(indexes, int):
            self.indexes = list(range(indexes))

        if self.batch_size is None:
            self.batch_size = len(self.indexes)

        # Stack image index as index 0, stack sequences as index -1
        self._images = np.stack(
            [np.stack(self.images[k], axis=0) for k in sequences], axis=-1)
        self._segmentations = np.stack(segmentations, axis=0)[..., np.newaxis]

    def __len__(self):
        return math.ceil(len(self.indexes) / self.batch_size)

    def __getitem__(self, batch_idx):
        # Get image / segmentation indexes for the current batch
        indexes = self.indexes[
            batch_idx * self.batch_size:(1 + batch_idx) * self.batch_size]

        # Get matching images and segmentations
        images = self._images[indexes]
        segmentations = self._segmentations[indexes]

        # Apply data augmentation if a function is provided
        if self.augmentation:
            for img_idx in range(len(indexes)):
                images[img_idx], segmentations[img_idx] = self.augmentation(
                    images[img_idx], segmentations[img_idx])
        return images, segmentations

    def on_epoch_end(self):
        super().on_epoch_end()

        if self.shuffle:
            np.random.shuffle(self.indexes)
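A minimal usage sketch for this generator (random arrays and made-up sequence names; only the class above is from the commit):

    import numpy as np

    # 20 crops of 48x48x8 voxels for each of two sequences
    images = {s: [np.random.rand(48, 48, 8) for _ in range(20)]
              for s in ("dwi", "adc")}
    segmentations = [np.random.randint(0, 2, (48, 48, 8)).astype(float)
                     for _ in range(20)]

    gen = BatchGenerator(images, segmentations,
                         sequences=["dwi", "adc"], shape=(48, 48, 8),
                         indexes=list(range(20)), batch_size=4, shuffle=True)
    x, y = gen[0]  # x: (4, 48, 48, 8, 2), y: (4, 48, 48, 8, 1)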
code/DWI_exp/callbacks.py (Executable file, 96 lines)
@@ -0,0 +1,96 @@
import numpy as np
import SimpleITK as sitk
from .helpers import *
import tensorflow.keras.backend as K
import tensorflow as tf
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import roc_auc_score, roc_curve


def dice_coef(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.round(K.flatten(y_pred))
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + K.epsilon()) / (
        K.sum(y_true_f) + K.sum(y_pred_f) + K.epsilon())


# def auc_value(y_true, y_pred):
#     # print_("start ROC_AUC")
#     # y_true_old = self.validation_set[1].squeeze()
#     # y_pred_old = np.around(self.model.predict(self.validation_set[0], batch_size=1).squeeze())

#     y_true_auc = K.flatten(y_true)
#     # print_(f"The shape of y_true = {np.shape(y_true)}")

#     y_pred_auc = K.round(K.flatten(y_pred))
#     # print_(f"The shape of y_pred = {np.shape(y_pred)}")

#     # fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=1)
#     auc = roc_auc_score(y_true_auc, y_pred_auc)
#     print_('AUC:', auc)
#     print_('AUC shape:', np.shape(auc))
#     # print_(f'ROC: fpr={fpr}, tpr={tpr}, thresholds={thresholds}')
#     return auc


class IntermediateImages(Callback):
    def __init__(self, validation_set, prefix, sequences,
                 num_images=10):
        self.prefix = prefix
        self.num_images = num_images
        self.validation_set = (
            validation_set[0][:num_images, ...],
            validation_set[1][:num_images, ...]
        )

        # Export scan crops and targets once; they don't change during
        # training, so we export them only at initialization.
        for i in range(min(self.num_images, self.validation_set[0].shape[0])):
            for s_idx, s in enumerate(sequences):
                img_s = sitk.GetImageFromArray(
                    self.validation_set[0][i][..., s_idx].squeeze().T)
                sitk.WriteImage(img_s, f"{prefix}_{i:03d}_{s}.nii.gz")
            seg_s = sitk.GetImageFromArray(
                self.validation_set[1][i].squeeze().T)
            sitk.WriteImage(seg_s, f"{prefix}_{i:03d}_seg.nii.gz")

    def on_epoch_end(self, epoch, logs={}):
        # Predict on the validation set images
        predictions = self.model.predict(self.validation_set[0], batch_size=1)

        # print_("start ROC_AUC")
        # y_true_old = self.validation_set[1].squeeze()
        # y_pred_old = np.around(self.model.predict(self.validation_set[0], batch_size=1).squeeze())

        # y_true = np.array(y_true_old).flatten()
        # print_(f"The shape of y_true = {np.shape(y_true)}")

        # y_pred = np.array(y_pred_old).flatten()
        # print_(f"The shape of y_pred = {np.shape(y_pred)}")

        # fpr, tpr, thresholds = roc_curve(y_true, y_pred, pos_label=1)
        # auc = roc_auc_score(y_true, y_pred)
        # print_('AUC:', auc)
        # print_(f'ROC: fpr={fpr}, tpr={tpr}, thresholds={thresholds}')

        for i in range(min(self.num_images, self.validation_set[0].shape[0])):
            prd_s = sitk.GetImageFromArray(predictions[i].squeeze().T)
            prd_bin_s = sitk.GetImageFromArray(
                np.around(predictions[i]).astype(np.float32).squeeze().T)
            sitk.WriteImage(prd_s, f"{self.prefix}_{i:03d}_pred.nii.gz")
            sitk.WriteImage(prd_bin_s, f"{self.prefix}_{i:03d}_pred_bin.nii.gz")


# class RocCallback(Callback):
#     def __init__(self, validation_data):
#         # self.x = training_data[0]
#         # self.y = training_data[1]
#         self.x_val = validation_data[0]
#         self.y_val = validation_data[1]

#     def on_epoch_end(self, epoch, logs={}):
#         # y_pred_train = self.model.predict_proba(self.x)
#         # roc_train = roc_auc_score(self.y, y_pred_train)
#         y_pred_val = self.model.predict(self.x_val, batch_size=1)
#         roc_val = roc_auc_score(self.y_val, y_pred_val)
#         print('\rroc-auc_val: %s' % str(round(roc_val, 4)), end=100*' '+'\n')
#         return
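A sketch of how these pieces would be wired into training (the model, validation arrays, generator, and output prefix are hypothetical; only dice_coef and IntermediateImages come from this file):

    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=[dice_coef])

    # Writes val_x/val_y crops once, then predictions after every epoch
    callbacks = [IntermediateImages(validation_set=(val_x, val_y),
                                    prefix="output/val",
                                    sequences=["dwi", "adc"])]
    model.fit(train_gen, validation_data=(val_x, val_y),
              epochs=100, callbacks=callbacks)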
code/DWI_exp/helpers.py (Executable file, 239 lines)
@@ -0,0 +1,239 @@
import os
from os import path

import numpy as np
import SimpleITK as sitk
from scipy import ndimage
import yaml


def print_(*args, **kwargs):
    print(*args, **kwargs, flush=True)


def print_config(config, depth=0):
    for k in config:
        print(end=(" " * depth * 2) + f"- {k}:")
        if isinstance(config[k], dict):
            print()
            print_config(config[k], depth + 1)
        else:
            print(f"\t{config[k]}")
    print_(end="")


def prepare_project_dir(project_dir):
    """
    Prepares the work directory for a project.
    TODO: Create text file containing configuration
    """
    def _makedir(n):
        os.makedirs(path.join(project_dir, n), exist_ok=True)

    _makedir("models")
    _makedir("logs")
    _makedir("data")
    _makedir("docs")
    _makedir("samples")
    _makedir("output")


def fake_loss(*args, **kwargs):
    return 0.


def center_crop(
    image: sitk.Image, shape: list, offset: list = None
) -> sitk.Image:
    """Extracts a region of interest from the center of an SITK image.

    Parameters:
    image: Input image (SITK).
    shape: The shape of the region to extract.
    offset: Optional offset of the crop relative to the image center.

    Returns: Cropped image (SITK)
    """
    size = image.GetSize()

    # Determine the centroid
    centroid = [sz / 2 for sz in size]

    # Determine the origin for the bounding box by subtracting half the
    # shape of the bounding box from each dimension of the centroid.
    box_start = [int(c - sh / 2) for c, sh in zip(centroid, shape)]

    if offset:
        box_start = [b - o for b, o in zip(box_start, offset)]

    # Extract the region of provided shape starting from the previously
    # determined starting pixel.
    region_extractor = sitk.RegionOfInterestImageFilter()
    region_extractor.SetSize(shape)
    region_extractor.SetIndex(box_start)
    cropped_image = region_extractor.Execute(image)

    return cropped_image


def augment(img, seg,
            noise_chance=0.3,
            noise_mult_max=0.001,  # noise_mult_max was initially 0.1
            rotate_chance=0.2,
            rotate_max_angle=30,
            ):
    if np.random.uniform() < noise_chance:
        img = augment_noise(img, np.random.uniform(0., noise_mult_max))

    if np.random.uniform() < rotate_chance:
        img, seg = augment_rotate(
            img, seg, angle=np.random.uniform(-rotate_max_angle, rotate_max_angle))
        img, seg = augment_tilt(
            img, seg, angle=np.random.uniform(-rotate_max_angle, rotate_max_angle))

    return img, seg


def augment_noise(img, multiplier):
    noise = np.random.standard_normal(img.shape) * multiplier
    return img + noise


def augment_rotate(img, seg, angle):
    img = ndimage.rotate(img, angle, reshape=False, cval=0)
    seg = ndimage.rotate(seg, angle, reshape=False, order=0)
    return img, seg


def augment_tilt(img, seg, angle):
    img = ndimage.rotate(img, angle, reshape=False, axes=(2, 1), cval=0)
    seg = ndimage.rotate(seg, angle, reshape=False, axes=(2, 1), order=0)
    return img, seg


def resample(
    image: sitk.Image, min_shape: list, method=sitk.sitkLinear,
    new_spacing: list = [1, 1, 3.6]
) -> sitk.Image:
    """Resamples an image to a given target spacing and shape.

    Parameters:
    image: Input image (SITK).
    min_shape: Minimum output shape for the underlying array.
    method: SimpleITK interpolator to use for resampling
        (e.g. sitk.sitkNearestNeighbor, sitk.sitkLinear).
    new_spacing: The new spacing to resample to.

    Returns: Resampled image (SITK)
    """
    # Extract size and spacing from the image
    size = image.GetSize()
    spacing = image.GetSpacing()

    # Determine how much larger the image will become with the new spacing
    factor = [sp / new_sp for sp, new_sp in zip(spacing, new_spacing)]

    # Determine the outcome size of the image for each dimension
    get_size = lambda sz, f, m: max(int(sz * f), m)
    new_size = [get_size(sz, f, m) for sz, f, m in zip(size, factor, min_shape)]

    # Resample the image
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(image)
    resampler.SetOutputSpacing(new_spacing)
    resampler.SetSize(new_size)
    resampler.SetInterpolator(method)
    resampled_image = resampler.Execute(image)

    return resampled_image


def get_generator(
    images: dict,
    segmentations: list,
    sequences: list,
    shape: tuple,
    indexes: list = None,
    batch_size: int = 5,
    shuffle: bool = False,
    augmentation=True
):
    """
    Returns a (training) generator for use with model.fit().

    Parameters:
    images: Dict mapping each sequence name to a list of image crops.
    segmentations: List of segmentation crops.
    sequences: List of sequence names to include as input channels.
    shape: Shape of a single image crop.
    indexes: Only use the specified image indexes.
    batch_size: Number of images per batch (default: all).
    shuffle: Shuffle the list of indexes once at the beginning.
    augmentation: Apply augmentation or not (bool).
    """
    num_rows = len(images)

    if indexes is None:
        indexes = list(range(num_rows))

    if isinstance(indexes, int):
        indexes = list(range(indexes))

    if batch_size is None:
        batch_size = len(indexes)

    idx = 0

    # Prepare empty batch placeholder with named inputs and outputs
    input_batch = np.zeros((batch_size,) + shape + (len(images),))
    output_batch = np.zeros((batch_size,) + shape + (1,))

    # Loop infinitely to keep generating batches
    while True:
        # Prepare each observation in a batch
        for batch_idx in range(batch_size):
            # Shuffle the order of images if all indexes have been seen
            if idx == 0 and shuffle:
                np.random.shuffle(indexes)
            current_index = indexes[idx]

            # Insert the augmented images into the input batch
            img_crop = [images[s][current_index] for s in sequences]
            img_crop = np.stack(img_crop, axis=-1)
            seg_crop = segmentations[current_index][..., np.newaxis]
            if augmentation:
                img_crop, seg_crop = augment(img_crop, seg_crop)
            input_batch[batch_idx] = img_crop
            output_batch[batch_idx] = seg_crop

            # Increase the current index and modulo by the number of rows
            # so that we stay within bounds
            idx = (idx + 1) % len(indexes)

        yield input_batch, output_batch


def resample_to_reference(image, ref_img,
                          interpolator=sitk.sitkNearestNeighbor,
                          default_pixel_value=0):
    resampled_img = sitk.Resample(image, ref_img,
                                  sitk.Transform(),
                                  interpolator, default_pixel_value,
                                  ref_img.GetPixelID())
    return resampled_img


def dump_dict_to_yaml(
        data: dict,
        target_dir: str,
        filename: str = "settings") -> None:
    """Writes the given dictionary as a yaml to the target directory.

    Parameters:
    data (dict): dictionary of data.
    target_dir (str): directory where the yaml will be saved.
    filename (str): name of the file without extension.
    """
    print("\nParameters")
    for pair in data.items():
        print(f"\t{pair}")
    print()

    path = f"{target_dir}/{filename}.yml"
    with open(path, 'w') as outfile:
        yaml.dump(data, outfile, default_flow_style=False)
    print_(f"Wrote yaml to: {path}")
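A sketch of driving model.fit with this infinite generator (hypothetical data, index list, and model; steps_per_epoch is required because the generator never terminates):

    gen = get_generator(images, segmentations,
                        sequences=["dwi", "adc"], shape=(48, 48, 8),
                        indexes=train_idx, batch_size=8,
                        shuffle=True, augmentation=True)
    model.fit(gen, steps_per_epoch=len(train_idx) // 8, epochs=100)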
code/DWI_exp/preprocessing_function.py (Executable file, 46 lines)
@@ -0,0 +1,46 @@
import numpy as np
import SimpleITK as sitk

from .helpers import *


def preprocess(imgs: dict, seg: sitk.Image,
               shape: tuple, spacing: tuple, to_numpy=True) -> tuple:

    # Resample all of the images to the desired voxel spacing
    img_r = {k: resample(imgs[k],
                         min_shape=[s + 1 for s in shape],
                         new_spacing=spacing) for k in imgs}
    seg_r = resample(seg,
                     min_shape=[s + 1 for s in shape],
                     new_spacing=spacing,
                     method=sitk.sitkNearestNeighbor)

    # Center crop one of the input images
    ref_seq = next(iter(img_r))
    ref_img = center_crop(img_r[ref_seq], shape=shape)

    # Then crop the remaining series / segmentation to match the input crop
    # by transforming them onto the physical space of the cropped series.
    img_crop = {k: resample_to_reference(
                    image=img_r[k],
                    ref_img=ref_img,
                    interpolator=sitk.sitkLinear)
                for k in img_r}
    seg_crop = resample_to_reference(
        image=seg_r,
        ref_img=ref_img,
        interpolator=sitk.sitkNearestNeighbor)

    # Return sitk.Image instead of numpy np.ndarray if requested
    if not to_numpy:
        return img_crop, seg_crop

    img_n = {k: sitk.GetArrayFromImage(img_crop[k]).T for k in img_crop}
    seg_n = sitk.GetArrayFromImage(seg_crop).T
    seg_n = np.clip(seg_n, 0., 1.)

    # Z-score normalize all images (scaled by 2 * std)
    for seq in img_n:
        img_n[seq] = (img_n[seq] - np.mean(img_n[seq])) / (2 * np.std(img_n[seq]))
    return img_n, seg_n
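A sketch of preprocessing a single study (file names are hypothetical; shape and spacing follow the crop sizes and default spacing used elsewhere in this commit):

    imgs = {"dwi": sitk.ReadImage("dwi.nii.gz"),
            "adc": sitk.ReadImage("adc.nii.gz")}
    seg = sitk.ReadImage("seg.nii.gz")

    img_n, seg_n = preprocess(imgs, seg,
                              shape=(144, 144, 18),
                              spacing=(1., 1., 3.6))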
code/DWI_exp/unet.py (Executable file, 114 lines)
@@ -0,0 +1,114 @@
from tensorflow.keras import backend as K
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv3D, Activation, Dense, concatenate, add
from tensorflow.keras.layers import Conv3DTranspose, LeakyReLU
from tensorflow.keras import regularizers
from tensorflow.keras.layers import GlobalAveragePooling3D, Reshape, multiply, Permute


def squeeze_excite_block(input, ratio=8):
    '''Create a channel-wise squeeze-excite block.

    Args:
    input: input tensor
    ratio: reduction ratio of the excitation bottleneck

    Returns: a Keras tensor

    References
    - [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
    '''
    init = input
    channel_axis = 1 if K.image_data_format() == "channels_first" else -1
    filters = init.shape[channel_axis]
    se_shape = (1, 1, 1, filters)

    se = GlobalAveragePooling3D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    if K.image_data_format() == 'channels_first':
        se = Permute((4, 1, 2, 3))(se)

    x = multiply([init, se])
    return x


def build_dual_attention_unet(
        input_shape,
        l2_regularization=0.0001,
):

    def conv_layer(x, kernel_size, out_filters, strides=(1, 1, 1)):
        x = Conv3D(out_filters, kernel_size,
                   strides=strides,
                   padding='same',
                   kernel_regularizer=regularizers.l2(l2_regularization),
                   kernel_initializer='he_normal',
                   use_bias=False
                   )(x)
        return x

    def conv_block(input, out_filters, strides=(1, 1, 1), with_residual=False, with_se=False, activation='relu'):
        # Strided convolution to downsample
        x = conv_layer(input, (3, 3, 3), out_filters, strides)
        x = Activation('relu')(x)

        # Unstrided convolution
        x = conv_layer(x, (3, 3, 3), out_filters)

        # Add a squeeze-excite block
        if with_se:
            se = squeeze_excite_block(x)
            x = add([x, se])

        # Add a residual connection using a 1x1x1 convolution with strides
        if with_residual:
            residual = conv_layer(input, (1, 1, 1), out_filters, strides)
            x = add([x, residual])

        # Activate whatever comes out of this
        if activation == 'leaky':
            x = LeakyReLU(alpha=.1)(x)
        else:
            x = Activation('relu')(x)

        return x

    # If we already have only one input, no need to combine anything
    inputs = Input(input_shape)

    # Downsampling
    conv1 = conv_block(inputs, 16)
    conv2 = conv_block(conv1, 32, strides=(2, 2, 1), with_residual=True, with_se=True)   # 72x72x18
    conv3 = conv_block(conv2, 64, strides=(2, 2, 1), with_residual=True, with_se=True)   # 36x36x18
    conv4 = conv_block(conv3, 128, strides=(2, 2, 2), with_residual=True, with_se=True)  # 18x18x9
    conv5 = conv_block(conv4, 256, strides=(2, 2, 2), with_residual=True, with_se=True)  # 9x9x9

    # First upsampling sequence
    up1_1 = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 2), padding='same')(conv5)  # 18x18x9
    up1_2 = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 2), padding='same')(up1_1)  # 36x36x18
    up1_3 = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 1), padding='same')(up1_2)  # 72x72x18
    bridge1 = concatenate([conv4, up1_1])  # 18x18x9 (128+128=256)
    dec_conv_1 = conv_block(bridge1, 128, with_residual=True, with_se=True, activation='leaky')  # 18x18x9

    # Second upsampling sequence
    up2_1 = Conv3DTranspose(64, (3, 3, 3), strides=(2, 2, 2), padding='same')(dec_conv_1)  # 36x36x18
    up2_2 = Conv3DTranspose(64, (3, 3, 3), strides=(2, 2, 1), padding='same')(up2_1)       # 72x72x18
    bridge2 = concatenate([conv3, up1_2, up2_1])  # 36x36x18 (64+128+64=256)
    dec_conv_2 = conv_block(bridge2, 64, with_residual=True, with_se=True, activation='leaky')

    # Final upsampling sequence
    up3_1 = Conv3DTranspose(32, (3, 3, 3), strides=(2, 2, 1), padding='same')(dec_conv_2)  # 72x72x18
    bridge3 = concatenate([conv2, up1_3, up2_2, up3_1])  # 72x72x18 (32+128+64+32=256)
    dec_conv_3 = conv_block(bridge3, 32, with_residual=True, with_se=True, activation='leaky')

    # Last upsampling to make the heatmap
    up4_1 = Conv3DTranspose(16, (3, 3, 3), strides=(2, 2, 1), padding='same')(dec_conv_3)  # 144x144x18
    dec_conv_4 = conv_block(up4_1, 16, with_residual=False, with_se=True, activation='leaky')  # 144x144x18 (16)

    # Reduce to a single output channel with a 1x1x1 convolution
    single_channel = Conv3D(1, (1, 1, 1))(dec_conv_4)

    # Apply sigmoid activation to get a binary prediction per voxel
    act = Activation('sigmoid')(single_channel)

    # Model definition
    model = Model(inputs=inputs, outputs=act)
    return model
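A sketch of instantiating the network (the input shape is an assumption based on the 144x144x18 crop comments above, with two input sequences as channels; dice_coef comes from callbacks.py):

    model = build_dual_attention_unet(input_shape=(144, 144, 18, 2))
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=[dice_coef])
    model.summary()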